Allow an arbitrary mask to be used in the self attention (#8235)
### Description
The aim of this PR is to enable the use of an arbitrary mask in the self-attention module, which is useful for missing data or masked modeling.
The official PyTorch implementations allow an arbitrary mask, and MONAI already makes masking possible through the `causal` argument. This PR simply generalizes that support directly in the forward pass.
In `SABlock` and `TransformerBlock`, it is now possible to pass a boolean mask of shape `(BS, Seq_length)`.
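A minimal usage sketch (the argument name `attn_mask` and the convention that `True` marks tokens to attend to are assumptions; check the merged docstrings for the exact signature):

```python
import torch
from monai.networks.blocks import SABlock

block = SABlock(hidden_size=128, num_heads=4)
x = torch.randn(2, 10, 128)  # (BS, Seq_length, hidden_size)

# Boolean mask of shape (BS, Seq_length); here the last three tokens of
# each sequence are treated as missing/padded (assumed convention:
# True = attend, False = mask out).
mask = torch.ones(2, 10, dtype=torch.bool)
mask[:, 7:] = False

out = block(x, attn_mask=mask)  # hypothetical argument name; output keeps x's shape
```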
Only the columns corresponding to masked tokens are set to `-inf`, not the rows; masking the rows as well is rare in common implementations, and masked tokens don't contribute to the gradient anyway.
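For illustration, a sketch of the column-masking idea (not the exact MONAI code): the `(BS, Seq_length)` boolean mask is broadcast over the attention logits so that masked keys (columns) receive `-inf` before the softmax:

```python
import torch

bs, heads, seq = 2, 4, 10
logits = torch.randn(bs, heads, seq, seq)  # raw attention scores

mask = torch.ones(bs, seq, dtype=torch.bool)
mask[:, 7:] = False  # last three tokens masked

# Broadcast to (bs, 1, 1, seq): only key columns are filled with -inf,
# while rows (queries) are left untouched.
logits = logits.masked_fill(~mask[:, None, None, :], float("-inf"))
attn = logits.softmax(dim=-1)  # masked columns get zero attention weight
```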
When causal attention is required, passing a mask is not supported, to avoid conflicts between the two masks.
I haven't implemented an additive mask on the attention matrix, which would allow values other than `-inf` to be used, as in PyTorch's `scaled_dot_product_attention`:
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
If you think it's relevant, it could be added.
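For reference, an additive mask with PyTorch's functional API would look like this (a sketch only; the `-1e4` penalty value is arbitrary):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 4, 10, 32)  # (BS, heads, Seq, head_dim)

# Float mask added to the logits: finite penalties are possible,
# unlike the boolean -inf masking described above.
additive = torch.zeros(2, 1, 10, 10)
additive[..., 7:] = -1e4  # strong but finite penalty on the last keys

out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive)
```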
### Types of changes
- [ ] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
Signed-off-by: Lucas Robinet <luca.robinet@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>