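import torch
import torch.nn.functional as F
from einops import rearrange

# text: (b, j) token ids; sim: (b, h, i, j) attention scores
cls_mask = rearrange(text != self.pad_id, 'b j -> b 1 j')        # (b, 1, j), True for non-padding tokens
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)          # (b, seq + 1, j + 1), padded entries are True
attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')             # (b, 1, seq + 1, j + 1), broadcasts over heads
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)   # masked positions get the most negative finite value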
Hello, I am confused by the implementation of "attn_mask". I think this padding call can only mask the last row of "sim". Could you please explain it? Perhaps it's a silly question. Thank you so much.
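To make the question concrete, here is a small toy reproduction of the snippet. It shows which rows of `attn_mask` actually carry the padding pattern. This is a sketch under assumed values: a batch of one, `pad_id = 0`, and four text tokens of which the last two are padding. It also assumes `seq` equals the number of text tokens, and it replaces the `einops.rearrange` calls with the equivalent `unsqueeze(1)` so that only PyTorch is needed.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy values (assumptions for illustration only)
pad_id = 0
text = torch.tensor([[5, 7, 0, 0]])  # (b=1, j=4); last two tokens are padding
seq = text.shape[1]                   # assumed: 4 text positions

# Same as rearrange(text != pad_id, 'b j -> b 1 j')
cls_mask = (text != pad_id).unsqueeze(1)                  # (1, 1, 4)

# Pad spec (left, right, top, bottom) for the last two dims:
#   key dim:   0 left, 1 right -> one extra all-True column at the end
#   query dim: seq top, 0 bottom -> seq all-True rows placed above cls_mask's row
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)  # (1, 5, 5)

# Same as rearrange(attn_mask, 'b i j -> b 1 i j'); the 1 broadcasts over heads
attn_mask = attn_mask.unsqueeze(1)                        # (1, 1, 5, 5)

print(attn_mask[0, 0].int().tolist())
# [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 0, 0, 1]]

# Apply the mask to random scores, as in the snippet
sim = torch.randn(1, 2, 5, 5)                             # (b, heads, i, j)
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim=-1)

# Only the last query row gives zero weight to the padded keys (columns 2 and 3)
print(attn[0, 0, -1])
```

Running it shows that the observation in the question holds for the mask itself. Every row produced by the `seq` top padding is all `True`. Only the final query row, the one appended after the `seq` text positions, excludes the padded keys. Any other restriction on the first `seq` rows would have to come from elsewhere in the attention, for example a separate causal mask. That part is not shown in the snippet above.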
anilbatra2185