
Commit 649c7c8

Lucas-rbnt authored, with KumoLiu and ericspod
Allow an arbitrary mask to be used in the self attention (#8235)
### Description

The aim of this PR is to enable the use of an arbitrary mask in the self-attention module, which is very useful in the case of missing data or masked modeling. Official torch implementations allow the use of an arbitrary mask, and in MONAI the use of a mask is already made possible through the `causal` argument. Here, it is simply a generalization made directly in the forward pass. In `SABlock` and `TransformerBlock`, it is now possible to pass a boolean mask of size `(BS, Seq_length)`. Only the columns of masked tokens are set to `-inf`, not the rows, as is rarely done in common implementations; masked tokens don't contribute to the gradient anyway. When causal attention is required, passing a mask is not supported, to avoid overlapping masks.

I haven't implemented the additive mask for the attention matrix, which would allow values other than `-inf` in certain cases, as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html. If you think it's relevant, it could be added.

### Types of changes

- [ ] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
Signed-off-by: Lucas Robinet <luca.robinet@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
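To make the new interface concrete, here is a minimal usage sketch (illustrative only; it mirrors the constructor arguments used in the added tests, and masking the last four tokens is an arbitrary choice standing in for missing data):

```python
import torch

from monai.networks.blocks.selfattention import SABlock

# a batch of 2 sequences of length 16 with hidden size 128
block = SABlock(hidden_size=128, num_heads=4)
x = torch.randn(2, 16, 128)

# boolean mask of shape (BS, Seq_length): True = keep, False = masked token
attn_mask = torch.ones(2, 16, dtype=torch.bool)
attn_mask[:, -4:] = False  # e.g. the last 4 tokens are missing/padded

out = block(x, attn_mask=attn_mask)  # same shape as x: (2, 16, 128)
```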
1 parent 3ee4cd2 commit 649c7c8

File tree: 3 files changed (+40 / -6 lines changed)


monai/networks/blocks/selfattention.py

Lines changed: 18 additions & 4 deletions
@@ -11,7 +11,7 @@

 from __future__ import annotations

-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -154,10 +154,12 @@ def __init__(
         )
         self.input_size = input_size

-    def forward(self, x):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
+                B x (s_dim_1 * ... * s_dim_n). Defaults to None.

         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -176,7 +178,13 @@ def forward(self, x):

         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
@@ -186,10 +194,16 @@ def forward(self, x):
                 att_mat = self.rel_positional_embedding(x, att_mat, q)

             if self.causal:
+                if attn_mask is not None:
+                    raise ValueError("Causal attention does not support attention masks.")
                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))

-            att_mat = att_mat.softmax(dim=-1)
+            if attn_mask is not None:
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+                attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+                att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))

+            att_mat = att_mat.softmax(dim=-1)
             if self.save_attn:
                 # no gradients and new tensor;
                 # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
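To see what the non-flash branch above does with the boolean mask, here is a standalone plain-PyTorch sketch (hypothetical shapes, not MONAI code): the `(B, L)` mask is broadcast to `(B, heads, 1, L)`, so only the columns of masked tokens are filled with `-inf` before the softmax.

```python
import torch

B, heads, L = 2, 4, 6
att_mat = torch.randn(B, heads, L, L)  # raw attention scores
# (B, L) boolean mask, False = masked token
mask = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0]]).bool()

# broadcast (B, L) -> (B, heads, 1, L): only key/column positions are masked
expanded = mask.unsqueeze(1).unsqueeze(2).expand(-1, heads, -1, -1)
att_mat = att_mat.masked_fill(expanded == 0, float("-inf"))
att_mat = att_mat.softmax(dim=-1)

# every query row now assigns zero weight to the masked columns
assert att_mat[0][..., ~mask[0]].sum() == 0
```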

monai/networks/blocks/transformerblock.py

Lines changed: 4 additions & 2 deletions
@@ -90,8 +90,10 @@ def __init__(
             use_flash_attention=use_flash_attention,
         )

-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x))
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))
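The mask simply threads through `TransformerBlock` into its attention block; a rough usage sketch (assuming the usual MONAI constructor arguments `hidden_size`, `mlp_dim`, `num_heads`) might look like:

```python
import torch

from monai.networks.blocks.transformerblock import TransformerBlock

block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, dropout_rate=0.0)
x = torch.randn(2, 16, 128)

attn_mask = torch.ones(2, 16, dtype=torch.bool)
attn_mask[:, 12:] = False  # mask out trailing padded tokens

out = block(x, attn_mask=attn_mask)  # (2, 16, 128)
```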

tests/test_selfattention.py

Lines changed: 18 additions & 0 deletions
@@ -122,6 +122,24 @@ def test_causal(self):
         # check upper triangular part of the attention matrix is zero
         assert torch.triu(block.att_mat, diagonal=1).sum() == 0

+    def test_masked_selfattention(self):
+        n = 64
+        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
+        input_shape = (1, n, 128)
+        # generate a mask randomly with zeros and ones of shape (1, n)
+        mask = torch.randint(0, 2, (1, n)).bool()
+        block(torch.randn(input_shape), attn_mask=mask)
+        att_mat = block.att_mat.squeeze()
+        # ensure all masked columns are zeros
+        assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)]))
+
+    def test_causal_and_mask(self):
+        with self.assertRaises(ValueError):
+            block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
+            inputs = torch.randn(2, 64, 128)
+            mask = torch.randint(0, 2, (2, 64)).bool()
+            block(inputs, attn_mask=mask)
+
     @skipUnless(has_einops, "Requires einops")
     def test_access_attn_matrix(self):
         # input format
