Commit 3bc726b
[gemma3] fix bidirectional image mask (#39396)
* fix gemma3 mask
* make compile happy, and use only torch ops
* no full attention between images
* update tests
* fix tests
* add a fast test
1 parent fbeaf96 commit 3bc726b

4 files changed (+112, -21 lines)
src/transformers/generation/utils.py

Lines changed: 2 additions & 2 deletions
@@ -646,8 +646,8 @@ def prepare_inputs_for_generation(

             # If it's not defined, it means the model uses the new general mask API
             if causal_mask_creation_function is None:  # can't be found
-                token_type_ids = getattr(model_input, "token_type_ids", None)
-                position_ids = getattr(model_input, position_ids_key, None)
+                token_type_ids = model_inputs.get("token_type_ids", None)
+                position_ids = model_inputs.get(position_ids_key, None)
                 # Some models may overwrite the general one
                 causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate)
                 attention_mask = causal_mask_creation_function(
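The fix here is a dictionary lookup: `model_inputs` is a plain dict, so the old `getattr(...)` calls could only return the default rather than the stored tensors. A minimal sketch of the distinction, using a hypothetical stand-in dict rather than the real `model_inputs` built during generation:

```python
import torch

# Hypothetical stand-in for the `model_inputs` dict assembled in `prepare_inputs_for_generation`
model_inputs = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "token_type_ids": torch.tensor([[0, 1, 1]]),
}

# `getattr` looks for an attribute, not a key, so on a dict it falls back to the default
print(getattr(model_inputs, "token_type_ids", None))  # None -> the image-token info is silently lost

# `dict.get` does the key lookup the fixed lines rely on
print(model_inputs.get("token_type_ids", None))  # tensor([[0, 1, 1]])
```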

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 32 additions & 6 deletions
@@ -737,7 +737,11 @@ def forward(self, vision_outputs: torch.Tensor):
         return projected_vision_outputs.type_as(vision_outputs)


-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
+def token_type_ids_mask_function(
+    token_type_ids: Optional[torch.Tensor],
+    image_group_ids: Optional[torch.Tensor],
+    tokens_per_image: int,
+) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -747,10 +751,18 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_
         return None

     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If the difference is less than image size, both are part of the same image block
-        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
         # If it's 1 for both query and key/value, we are in an image block
-        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+        # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
+        # Since vmap doesn't support `if statement` we workaround it with `torch.where`
+        safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
+        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
+        token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
+
+        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
+        image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
+
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
+        same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx

         # This is bidirectional attention whenever we are dealing with image tokens
         return is_image_block & same_image_block
@@ -915,8 +927,15 @@ def forward(
         }
         if token_type_ids is not None and inputs_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
+
+            # First find where a new image block starts: 1 if image and previous not image
+            # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+            is_image = (token_type_ids == 1).to(cache_position.device)
+            new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+            image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+            image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
             mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
+                token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image
             )

         # Create the masks
@@ -1181,8 +1200,15 @@ def create_masks_for_generate(
         # Add the token type ids mask for generate as well
         if token_type_ids is not None and input_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
+
+            # First find where a new image block starts: 1 if image and previous not image
+            # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+            is_image = (token_type_ids == 1).to(cache_position.device)
+            new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+            image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+            image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
             mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+                token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image
             )

         return create_masks_for_generate(**mask_kwargs)
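The key change above replaces the old distance test (`torch.abs(kv_idx - q_idx) <= tokens_per_image`), which could also open bidirectional attention between two nearby but distinct images, with explicit per-image group ids. A small, self-contained sketch of how those `image_group_ids` come out of `token_type_ids`, on toy tensors rather than real model inputs:

```python
import torch
import torch.nn as nn

# Toy token_type_ids for one sequence: a 4-token image, some text, then a 3-token image
token_type_ids = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0]])

# Same computation as in the diff: a new group starts wherever an image token
# follows a non-image token; text positions are assigned group id -1
is_image = token_type_ids == 1
new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))

print(image_group_ids)
# tensor([[ 0,  0,  0,  0, -1, -1, -1,  1,  1,  1, -1]])
# Tokens of the first image share group 0 and the second image group 1, so the
# `or_mask_function` only opens bidirectional attention inside a single image,
# never between two different images.
```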

src/transformers/models/gemma3/modular_gemma3.py

Lines changed: 32 additions & 6 deletions
@@ -716,7 +716,11 @@ def forward(self, vision_outputs: torch.Tensor):
         return projected_vision_outputs.type_as(vision_outputs)


-def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_per_image: int) -> Optional[Callable]:
+def token_type_ids_mask_function(
+    token_type_ids: Optional[torch.Tensor],
+    image_group_ids: Optional[torch.Tensor],
+    tokens_per_image: int,
+) -> Optional[Callable]:
     """
     This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
     not start and end indices.
@@ -726,10 +730,18 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor], tokens_
         return None

     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        # If the difference is less than image size, both are part of the same image block
-        same_image_block = torch.abs(kv_idx - q_idx) <= tokens_per_image
         # If it's 1 for both query and key/value, we are in an image block
-        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids[batch_idx, kv_idx] == 1)
+        # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
+        # Since vmap doesn't support `if statement` we workaround it with `torch.where`
+        safe_idx = torch.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
+        token_type_ids_at_kv_idx = token_type_ids[batch_idx, safe_idx]
+        token_type_ids_at_kv_idx = torch.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)
+
+        image_group_ids_at_kv_idx = image_group_ids[batch_idx, safe_idx]
+        image_group_ids_at_kv_idx = torch.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)
+
+        is_image_block = (token_type_ids[batch_idx, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
+        same_image_block = image_group_ids[batch_idx, q_idx] == image_group_ids_at_kv_idx

         # This is bidirectional attention whenever we are dealing with image tokens
         return is_image_block & same_image_block
@@ -840,8 +852,15 @@ def forward(
         }
         if token_type_ids is not None and inputs_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
+
+            # First find where a new image block starts: 1 if image and previous not image
+            # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+            is_image = (token_type_ids == 1).to(cache_position.device)
+            new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+            image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+            image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
             mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                token_type_ids.to(cache_position.device), self.config.mm_tokens_per_image
+                token_type_ids.to(cache_position.device), image_group_ids, self.config.mm_tokens_per_image
             )

         # Create the masks
@@ -1062,8 +1081,15 @@ def create_masks_for_generate(
         # Add the token type ids mask for generate as well
         if token_type_ids is not None and input_embeds.shape[1] != 1:
             # We need to pass an additional mask function to account for token type ids, and it needs to be an `or`
+
+            # First find where a new image block starts: 1 if image and previous not image
+            # The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
+            is_image = (token_type_ids == 1).to(cache_position.device)
+            new_image_start = is_image & ~nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
+            image_group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
+            image_group_ids = torch.where(is_image, image_group_ids, torch.full_like(token_type_ids, -1))
             mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
-                token_type_ids.to(cache_position.device), config.mm_tokens_per_image
+                token_type_ids.to(cache_position.device), image_group_ids, config.mm_tokens_per_image
             )

         return create_masks_for_generate(**mask_kwargs)
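Both the generated `modeling_gemma3.py` and this modular source also pick up the vmap-friendly lookup in `inner_mask`: with a static cache, `kv_idx` can point past the end of `token_type_ids`, and the mask function cannot use a Python `if` because it gets traced and vmapped, so the out-of-range case is handled with `torch.where`. A standalone sketch of that pattern on a toy tensor (not the real mask API):

```python
import torch

token_type_ids = torch.tensor([[1, 1, 0, 0]])  # input length 4
seq_len = token_type_ids.shape[1]

def token_type_at(batch_idx: int, kv_idx: torch.Tensor) -> torch.Tensor:
    # Direct indexing would fail once kv_idx >= seq_len (static-cache positions),
    # and a Python `if` guard would not survive vmap/compile, so clamp with torch.where
    safe_idx = torch.where(kv_idx < seq_len, kv_idx, 0)
    value = token_type_ids[batch_idx, safe_idx]
    # Positions beyond the input are treated as non-image (0)
    return torch.where(kv_idx < seq_len, value, 0)

for kv in (torch.tensor(1), torch.tensor(6)):  # 6 is beyond the 4-token input
    print(int(token_type_at(0, kv)))  # -> 1, then 0
```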

tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 46 additions & 7 deletions
@@ -270,6 +270,45 @@ def setUp(self):
         self.model_tester = Gemma3Vision2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)

+    def test_bidirectional_image_attention(self):
+        """
+        Tests that each image can attend to itself bidirectionally. However an image
+        cannot attend to future images, even within the same batch.
+        """
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config._attn_implementation = "eager"
+        model = Gemma3Model(config).to(torch_device)
+
+        # First let's pass inputs without change which is one image per text and manipulate
+        # `token_type_ids` to make sure bidirectional mask is applied where it has to be
+        inputs_dict["token_type_ids"] = torch.zeros_like(inputs_dict["token_type_ids"])
+        inputs_dict["token_type_ids"][:, :4] = 1  # unmask first 4 tokens
+        with torch.no_grad():
+            out = model(**inputs_dict, output_attentions=True)
+        # We expect a non-causal mask on first 4 tokens, thus no zeros
+        for attention in out.attentions:
+            self.assertTrue((attention[..., :4, :4] != 0).all().item())
+
+        # Now when removing `token_type_ids`, we will get simple causal mask
+        inputs_dict["token_type_ids"][:, :4] = 0  # mask back first 4 tokens
+        with torch.no_grad():
+            out = model(**inputs_dict, output_attentions=True)
+        # We expect a causal mask on first 4 tokens, thus some zeros
+        for attention in out.attentions:
+            self.assertFalse((attention[..., :4, :4] != 0).all().item())
+
+        # Let's add two "images" per text, first one spanning 4 tokens and last one 3 tokens
+        inputs_dict["token_type_ids"][:, :4] = 1
+        inputs_dict["token_type_ids"][:, 7:10] = 1
+        with torch.no_grad():
+            out = model(**inputs_dict, output_attentions=True)
+        for attention in out.attentions:
+            self.assertTrue((attention[..., :4, :4] != 0).all().item())
+            self.assertTrue((attention[..., 7:10, 7:10] != 0).all().item())
+
+            # We expect a non-causal mask only within same image and no looking ahead to the future
+            self.assertTrue((attention[..., :4, 7:10] == 0).all().item())
+
     @unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
     def test_training_gradient_checkpointing(self):
         pass
@@ -413,7 +452,7 @@ def test_model_4b_bf16(self):
         EXPECTED_TEXTS = Expectations(
             {
                 ("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water in the background. It looks like a lovely,'],
-                ("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. It looks'],
+                ("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like'],
                 ("rocm", (9, 5)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. It looks'],
             }
         )  # fmt: skip
@@ -463,8 +502,8 @@ def test_model_4b_batch(self):
                 ],
                 ("cuda", 8):
                 [
-                    'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks',
-                    'user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. They depict very different scenes. \n\n* **Image 1** shows a cow standing on a beach'
+                    'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like',
+                    "user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a brown"
                 ],
                 ("rocm", (9, 5)):
                 [
@@ -508,7 +547,7 @@ def test_model_4b_crops(self):
             {
                 ("xpu", 3): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'],
                 ("cuda", 7): [],
-                ("cuda", 8): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'],
+                ("cuda", 8): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the"],
             }
         )  # fmt: skip
         EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
@@ -565,8 +604,8 @@ def test_model_4b_batch_crops(self):
                 ],
                 ("cuda", 7): [],
                 ("cuda", 8): [
-                    'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
-                    'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a',
+                    "user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the",
+                    'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a'
                 ],
                 ("rocm", (9, 5)) : [
                     'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
@@ -610,7 +649,7 @@ def test_model_4b_multiimage(self):
             {
                 ("xpu", 3): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image!\n\nHere's a description of the scene:\n\n* **Chinese Arch"],
                 ("cuda", 7): [],
-                ("cuda", 8): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Main Features:**\n\n* **Chinese Archway:** The most prominent"],
+                ("cuda", 8): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"],
             }
         )  # fmt: skip
         EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()