1313# limitations under the License.
1414"""Testing suite for the PyTorch Fuyu model."""
1515
16+ import copy
1617import io
1718import unittest
1819
1920import pytest
2021import requests
22+ import torch
2123from parameterized import parameterized
2224
2325from transformers import FuyuConfig , is_torch_available , is_vision_available
24- from transformers .testing_utils import require_torch , require_torch_accelerator , slow
26+ from transformers .testing_utils import require_torch , require_torch_accelerator , slow , torch_device
2527from transformers .utils import cached_property
2628
2729from ...generation .test_utils import GenerationTesterMixin
28- from ...test_modeling_common import ModelTesterMixin , ids_tensor , random_attention_mask
30+ from ...test_modeling_common import ModelTesterMixin , floats_tensor , ids_tensor , random_attention_mask
2931from ...test_pipeline_mixin import PipelineTesterMixin
3032
3133
@@ -47,6 +49,7 @@ def __init__(
4749 parent ,
4850 batch_size = 13 ,
4951 seq_length = 7 ,
52+ num_image_tokens = 2 ,
5053 image_size = 30 ,
5154 patch_size = 15 ,
5255 num_channels = 3 ,
@@ -67,12 +70,14 @@ def __init__(
6770 initializer_range = 0.02 ,
6871 num_labels = 3 ,
6972 num_choices = 4 ,
70- pad_token_id = 0 ,
73+ pad_token_id = 10 ,
74+ image_token_id = 1 ,
7175 scope = None ,
7276 ):
7377 self .parent = parent
7478 self .batch_size = batch_size
75- self .seq_length = seq_length
79+ self .num_image_tokens = num_image_tokens
80+ self .seq_length = seq_length + num_image_tokens
7681 self .image_size = image_size
7782 self .patch_size = patch_size
7883 self .num_channels = num_channels
@@ -94,10 +99,15 @@ def __init__(
9499 self .num_labels = num_labels
95100 self .num_choices = num_choices
96101 self .pad_token_id = pad_token_id
102+ self .image_token_id = image_token_id
97103 self .scope = scope
98104
99105 def prepare_config_and_inputs (self ):
106+ config = self .get_config ()
107+
100108 input_ids = ids_tensor ([self .batch_size , self .seq_length ], self .vocab_size )
109+ input_ids [input_ids == config .image_token_id ] = self .pad_token_id
110+ input_ids [:, : self .num_image_tokens ] = config .image_token_id
101111
102112 input_mask = None
103113 if self .use_input_mask :
@@ -109,8 +119,6 @@ def prepare_config_and_inputs(self):
109119 sequence_labels = ids_tensor ([self .batch_size ], self .type_sequence_label_size )
110120 token_labels = ids_tensor ([self .batch_size , self .seq_length ], self .num_labels )
111121
112- config = self .get_config ()
113-
114122 return config , input_ids , input_mask , sequence_labels , token_labels
115123
116124 def get_config (self ):
@@ -128,6 +136,7 @@ def get_config(self):
128136 is_decoder = False ,
129137 initializer_range = self .initializer_range ,
130138 pad_token_id = self .pad_token_id ,
139+ image_token_id = self .image_token_id ,
131140 )
132141
133142 def prepare_config_and_inputs_for_common (self ):
@@ -139,7 +148,10 @@ def prepare_config_and_inputs_for_common(self):
139148 sequence_labels ,
140149 token_labels ,
141150 ) = config_and_inputs
142- inputs_dict = {"input_ids" : input_ids , "attention_mask" : input_mask }
151+ image_patches = floats_tensor (
152+ [self .batch_size , self .num_image_tokens , config .num_channels * config .patch_size ** 2 ]
153+ )
154+ inputs_dict = {"input_ids" : input_ids , "attention_mask" : input_mask , "image_patches" : image_patches }
143155 return config , inputs_dict
144156
145157
@@ -166,6 +178,27 @@ class FuyuModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
166178 def setUp (self ):
167179 self .model_tester = FuyuModelTester (self )
168180
181+ def test_mismatching_image_patches (self ):
182+ config , input_dict = self .model_tester .prepare_config_and_inputs_for_common ()
183+ for model_class in self .all_model_classes :
184+ model = model_class (config ).to (torch_device )
185+ curr_input_dict = copy .deepcopy (input_dict ) # in=place modifications further
186+
187+ # two image token and two image
188+ _ = model (** curr_input_dict ) # successful forward with no modifications
189+
190+ # remove one image but leave the image token in text
191+ input_ids = curr_input_dict ["input_ids" ]
192+ image_patches = curr_input_dict ["image_patches" ][1 :, ...]
193+ with self .assertRaises (ValueError ):
194+ _ = model (input_ids = input_ids , image_patches = image_patches )
195+
196+ # remove one image token from text
197+ input_ids = curr_input_dict ["input_ids" ][2 :]
198+ image_patches = curr_input_dict ["image_patches" ]
199+ with self .assertRaises (ValueError ):
200+ _ = model (input_ids = input_ids , image_patches = image_patches )
201+
169202 @unittest .skip (
170203 reason = "This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
171204 )
@@ -232,7 +265,7 @@ def default_processor(self):
232265
233266 @cached_property
234267 def default_model (self ):
235- return FuyuForCausalLM .from_pretrained ("adept/fuyu-8b" )
268+ return FuyuForCausalLM .from_pretrained ("adept/fuyu-8b" , torch_dtype = "float16" , device_map = torch_device )
236269
237270 def test_greedy_generation (self ):
238271 processor = self .default_processor
@@ -243,7 +276,9 @@ def test_greedy_generation(self):
243276
244277 text_prompt_coco_captioning = "Generate a coco-style caption.\n "
245278
246- inputs = processor (images = image , text = text_prompt_coco_captioning , return_tensors = "pt" )
279+ inputs = processor (images = image , text = text_prompt_coco_captioning , return_tensors = "pt" ).to (
280+ torch_device , torch .float16
281+ )
247282 generated_ids = model .generate (** inputs , max_new_tokens = 10 )
248283
249284 # take the last 8 tokens (in order to skip special \n\x04 characters) and decode them
0 commit comments