Skip to content

Commit 95ae07d

Browse files
Isotr0py authored and ArthurZucker committed
Fix broken image inference for Fuyu model (#39915)
* fix fuyu
  Signed-off-by: Isotr0py <2037008807@qq.com>
* oops
  Signed-off-by: Isotr0py <2037008807@qq.com>
* run test on GPU
  Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
* clean unused
  Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
* revert
  Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
* add fuyu multimodal test
  Signed-off-by: Isotr0py <2037008807@qq.com>
* fix
  Signed-off-by: Isotr0py <2037008807@qq.com>
---------
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent 0d9032a commit 95ae07d

File tree

2 files changed

+49
-11
lines changed

2 files changed

+49
-11
lines changed

src/transformers/models/fuyu/modeling_fuyu.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def forward(
225225
if image_patches is not None:
226226
patch_embeddings = self.get_image_features(image_patches)
227227
patch_embeddings = torch.cat(patch_embeddings, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
228-
special_image_mask = self.get_placeholder_tokens(
228+
special_image_mask = self.get_placeholder_mask(
229229
input_ids, inputs_embeds=inputs_embeds, image_features=patch_embeddings
230230
)
231231
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)
@@ -379,6 +379,7 @@ def prepare_inputs_for_generation(
379379
inputs_embeds=None,
380380
image_patches=None,
381381
image_patches_indices=None,
382+
cache_position=None,
382383
**kwargs,
383384
):
384385
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -390,10 +391,12 @@ def prepare_inputs_for_generation(
390391
inputs_embeds=inputs_embeds,
391392
image_patches=image_patches,
392393
image_patches_indices=image_patches_indices,
394+
cache_position=cache_position,
393395
**kwargs,
394396
)
395397

396-
if past_key_values is not None:
398+
if cache_position[0] != 0:
399+
# set image_patches and image_patches_indices to `None` for decoding stage
397400
model_inputs["image_patches_indices"] = None
398401
model_inputs["image_patches"] = None
399402

tests/models/fuyu/test_modeling_fuyu.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,21 @@
1313
# limitations under the License.
1414
"""Testing suite for the PyTorch Fuyu model."""
1515

16+
import copy
1617
import io
1718
import unittest
1819

1920
import pytest
2021
import requests
22+
import torch
2123
from parameterized import parameterized
2224

2325
from transformers import FuyuConfig, is_torch_available, is_vision_available
24-
from transformers.testing_utils import require_torch, require_torch_accelerator, slow
26+
from transformers.testing_utils import require_torch, require_torch_accelerator, slow, torch_device
2527
from transformers.utils import cached_property
2628

2729
from ...generation.test_utils import GenerationTesterMixin
28-
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
30+
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
2931
from ...test_pipeline_mixin import PipelineTesterMixin
3032

3133

@@ -47,6 +49,7 @@ def __init__(
4749
parent,
4850
batch_size=13,
4951
seq_length=7,
52+
num_image_tokens=2,
5053
image_size=30,
5154
patch_size=15,
5255
num_channels=3,
@@ -67,12 +70,14 @@ def __init__(
6770
initializer_range=0.02,
6871
num_labels=3,
6972
num_choices=4,
70-
pad_token_id=0,
73+
pad_token_id=10,
74+
image_token_id=1,
7175
scope=None,
7276
):
7377
self.parent = parent
7478
self.batch_size = batch_size
75-
self.seq_length = seq_length
79+
self.num_image_tokens = num_image_tokens
80+
self.seq_length = seq_length + num_image_tokens
7681
self.image_size = image_size
7782
self.patch_size = patch_size
7883
self.num_channels = num_channels
@@ -94,10 +99,15 @@ def __init__(
9499
self.num_labels = num_labels
95100
self.num_choices = num_choices
96101
self.pad_token_id = pad_token_id
102+
self.image_token_id = image_token_id
97103
self.scope = scope
98104

99105
def prepare_config_and_inputs(self):
106+
config = self.get_config()
107+
100108
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
109+
input_ids[input_ids == config.image_token_id] = self.pad_token_id
110+
input_ids[:, : self.num_image_tokens] = config.image_token_id
101111

102112
input_mask = None
103113
if self.use_input_mask:
@@ -109,8 +119,6 @@ def prepare_config_and_inputs(self):
109119
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
110120
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
111121

112-
config = self.get_config()
113-
114122
return config, input_ids, input_mask, sequence_labels, token_labels
115123

116124
def get_config(self):
@@ -128,6 +136,7 @@ def get_config(self):
128136
is_decoder=False,
129137
initializer_range=self.initializer_range,
130138
pad_token_id=self.pad_token_id,
139+
image_token_id=self.image_token_id,
131140
)
132141

133142
def prepare_config_and_inputs_for_common(self):
@@ -139,7 +148,10 @@ def prepare_config_and_inputs_for_common(self):
139148
sequence_labels,
140149
token_labels,
141150
) = config_and_inputs
142-
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
151+
image_patches = floats_tensor(
152+
[self.batch_size, self.num_image_tokens, config.num_channels * config.patch_size**2]
153+
)
154+
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask, "image_patches": image_patches}
143155
return config, inputs_dict
144156

145157

@@ -166,6 +178,27 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
166178
def setUp(self):
167179
self.model_tester = FuyuModelTester(self)
168180

181+
def test_mismatching_image_patches(self):
182+
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
183+
for model_class in self.all_model_classes:
184+
model = model_class(config).to(torch_device)
185+
curr_input_dict = copy.deepcopy(input_dict) # guard against in-place modifications further down
186+
187+
# two image tokens and two images
188+
_ = model(**curr_input_dict) # successful forward with no modifications
189+
190+
# remove one image but leave the image token in text
191+
input_ids = curr_input_dict["input_ids"]
192+
image_patches = curr_input_dict["image_patches"][1:, ...]
193+
with self.assertRaises(ValueError):
194+
_ = model(input_ids=input_ids, image_patches=image_patches)
195+
196+
# drop the first two rows of input_ids (batch dimension) while keeping every image patch
197+
input_ids = curr_input_dict["input_ids"][2:]
198+
image_patches = curr_input_dict["image_patches"]
199+
with self.assertRaises(ValueError):
200+
_ = model(input_ids=input_ids, image_patches=image_patches)
201+
169202
@unittest.skip(
170203
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
171204
)
@@ -232,7 +265,7 @@ def default_processor(self):
232265

233266
@cached_property
234267
def default_model(self):
235-
return FuyuForCausalLM.from_pretrained("adept/fuyu-8b")
268+
return FuyuForCausalLM.from_pretrained("adept/fuyu-8b", torch_dtype="float16", device_map=torch_device)
236269

237270
def test_greedy_generation(self):
238271
processor = self.default_processor
@@ -243,7 +276,9 @@ def test_greedy_generation(self):
243276

244277
text_prompt_coco_captioning = "Generate a coco-style caption.\n"
245278

246-
inputs = processor(images=image, text=text_prompt_coco_captioning, return_tensors="pt")
279+
inputs = processor(images=image, text=text_prompt_coco_captioning, return_tensors="pt").to(
280+
torch_device, torch.float16
281+
)
247282
generated_ids = model.generate(**inputs, max_new_tokens=10)
248283

249284
# take the last 8 tokens (in order to skip special \n\x04 characters) and decode them

0 commit comments

Comments
 (0)