1414# limitations under the License.
1515"""Testing suite for the PyTorch GraniteMoeHybrid model."""
1616
17+ import inspect
18+ import tempfile
1719import unittest
1820
1921import pytest
22+ from pytest import mark
2023
2124from transformers import (
2225 AutoTokenizer ,
26+ DataCollatorWithFlattening ,
2327 GraniteMoeHybridConfig ,
2428 is_torch_available ,
2529)
2630from transformers .testing_utils import (
31+ require_flash_attn ,
2732 require_torch ,
2833 require_torch_gpu ,
2934 slow ,
3035 torch_device ,
3136)
3237
3338from ...generation .test_utils import GenerationTesterMixin
34- from ...models .bamba .test_modeling_bamba import BambaModelTest , BambaModelTester
39+ from ...models .bamba .test_modeling_bamba import BambaModelTester
40+ from ...test_configuration_common import ConfigTester
41+ from ...test_modeling_common import ModelTesterMixin
42+ from ...test_pipeline_mixin import PipelineTesterMixin
3543
3644
3745if is_torch_available ():
@@ -77,7 +85,7 @@ def get_config(self):
7785
7886
7987@require_torch
80- class GraniteMoeHybridModelTest (BambaModelTest , GenerationTesterMixin , unittest .TestCase ):
88+ class GraniteMoeHybridModelTest (ModelTesterMixin , GenerationTesterMixin , PipelineTesterMixin , unittest .TestCase ):
8189 model_tester_class = GraniteMoeHybridModelTester
8290 all_model_classes = (
8391 (
@@ -96,6 +104,225 @@ class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest.
96104 else {}
97105 )
98106
# `test_cpu_offload` needs `0.8` rather than `0.9` as the last split percent:
# with `0.9` we hit edge cases with the causal_mask buffer.
model_split_percents = [0.5, 0.7, 0.8]
110+
def _check_caches_are_equal(
    self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache
):
    """Check that two hybrid caches hold matching per-layer states.

    For every layer index the `key_cache`, `value_cache`, `conv_states` and `ssm_states`
    entries of both caches are compared with `torch.testing.assert_close`.

    Raises:
        ValueError: if either argument is not a `HybridMambaAttentionDynamicCache`, or if
            the two caches report different lengths.
    """
    for candidate in (cache1, cache2):
        if not isinstance(candidate, HybridMambaAttentionDynamicCache):
            raise ValueError("The wrong cache is being used!")

    if len(cache1) != len(cache2):
        raise ValueError("Both caches do not have the same number of layers.")

    # Same comparison order per layer as before: keys, values, conv states, SSM states.
    state_names = ("key_cache", "value_cache", "conv_states", "ssm_states")
    for layer_idx in range(len(cache1)):
        for state_name in state_names:
            torch.testing.assert_close(
                getattr(cache1, state_name)[layer_idx],
                getattr(cache2, state_name)[layer_idx],
            )
128+
def setUp(self):
    """Build the model tester and the config tester used by this test case."""
    tester = self.model_tester_class(self)
    self.model_tester = tester
    self.config_tester = ConfigTester(self, config_class=tester.config_class, hidden_size=64)
132+
133+ def test_config (self ):
134+ self .config_tester .run_common_tests ()
135+
136+ def test_model (self ):
137+ config_and_inputs = self .model_tester .prepare_config_and_inputs ()
138+ self .model_tester .create_and_check_model (* config_and_inputs )
139+
140+ def test_for_causal_lm (self ):
141+ config_and_inputs = self .model_tester .prepare_config_and_inputs ()
142+ self .model_tester .create_and_check_for_causal_lm (* config_and_inputs )
143+
144+ def test_decoder_model_past_with_large_inputs (self ):
145+ config_and_inputs = self .model_tester .prepare_config_and_inputs ()
146+ self .model_tester .create_and_check_decoder_model_past_large_inputs (* config_and_inputs )
147+
def test_attention_outputs(self):
    r"""
    Override of the common `test_attention_outputs`: the model is expected to return attention weights only
    for its attention layers, so the number of returned attention tensors is checked against the number of
    attention layers (`attn_layer_indices`) instead of the total number of hidden layers.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    config.return_dict = True

    seq_len = getattr(self.model_tester, "seq_length", None)
    encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
    encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

    # One attention tensor per attention layer. The former expression,
    # `num_hidden_layers - len(attn_layer_indices)`, counted the non-attention layers and only
    # matched this value when exactly half of the layers were attention layers.
    expected_num_attentions = len(self.model_tester.attn_layer_indices)
    expected_attention_shape = [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length]

    for model_class in self.all_model_classes:
        inputs_dict["output_attentions"] = True
        inputs_dict["output_hidden_states"] = False
        config.return_dict = True
        model = model_class._from_config(config, attn_implementation="eager")
        # Keep working with the config held by the eager model for the re-instantiations below.
        config = model.config
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
        attentions = outputs.attentions
        self.assertEqual(len(attentions), expected_num_attentions)

        # check that output_attentions also work using config
        del inputs_dict["output_attentions"]
        config.output_attentions = True
        model = model_class(config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
        attentions = outputs.attentions
        self.assertEqual(len(attentions), expected_num_attentions)

        self.assertListEqual(list(attentions[0].shape[-3:]), expected_attention_shape)
        out_len = len(outputs)

        # Check attention is always last and order is fine
        inputs_dict["output_attentions"] = True
        inputs_dict["output_hidden_states"] = True
        model = model_class(config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

        # Requesting hidden states adds exactly one entry to the outputs.
        added_hidden_states = 1
        self.assertEqual(out_len + added_hidden_states, len(outputs))

        self_attentions = outputs.attentions

        self.assertEqual(len(self_attentions), expected_num_attentions)
        self.assertListEqual(list(self_attentions[0].shape[-3:]), expected_attention_shape)
211+
def test_batching_equivalence(self):
    """Run the inherited batching-equivalence test with the tester's input mask turned off.

    The tester's `use_input_mask` flag is put back to its previous value afterwards, including
    when the inherited test fails or is skipped.
    """
    # need to disable the tril input mask
    orig = self.model_tester.use_input_mask
    self.model_tester.use_input_mask = False
    try:
        super().test_batching_equivalence()
    finally:
        # Restore unconditionally so a failure/skip in the parent test cannot leak the override.
        self.model_tester.use_input_mask = orig
218+
@pytest.mark.generate
def test_left_padding_compatibility(self):
    """Run the inherited left-padding test with the unpadded `attention_mask` overridden to `None`."""
    # TODO: document why a random attention mask causes this test to fail, but a full mask doesn't
    super().test_left_padding_compatibility(unpadded_custom_inputs={"attention_mask": None})
224+
225+ @unittest .skip (
226+ "Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
227+ )
228+ def test_flash_attention_2_padding_matches_padding_free_with_position_ids (self ):
229+ pass
230+
231+ @unittest .skip (
232+ "Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
233+ )
234+ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs (self ):
235+ pass
236+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@unittest.skip(
    "NotImplementedError: seq_idx support requires fast path support. Please install mamba_ssm and causal_conv1d"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_idx_and_fa_kwargs(self):
    """Compare a padded forward pass with a flattened (padding-free) one under Flash Attention 2.

    The test is currently disabled by the `unittest.skip` decorator. When enabled, for each
    generative model class it saves a freshly built model, reloads it in float16 with
    `attn_implementation="flash_attention_2"`, and checks that logits (at non-padded positions)
    and losses agree between the padded batch and the batch produced by
    `DataCollatorWithFlattening` with `return_seq_idx=True` and `return_flash_attn_kwargs=True`.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    max_new_tokens = 30

    for model_class in self.all_generative_model_classes:
        if not model_class._supports_flash_attn:
            self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # The comparison is only meaningful when the dummy batch actually contains padding.
        if "attention_mask" not in inputs_dict or 0 not in inputs_dict["attention_mask"]:
            self.skipTest("Model dummy inputs should contain padding in their attention mask")

        dummy_input = inputs_dict[model_class.main_input_name]
        if dummy_input.dtype in (torch.float32, torch.bfloat16):
            dummy_input = dummy_input.to(torch.float16)

        # make sure that all models have enough positions for generation
        if hasattr(config, "max_position_embeddings"):
            config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

        model = model_class(config)
        forward_params = inspect.signature(model.forward).parameters
        if "position_ids" not in forward_params:
            self.skipTest("Model does not support position_ids")

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            # ensure left padding, to adapt for some models
            if 0 in inputs_dict["attention_mask"][:, -1]:
                inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
            padding_positions = ~inputs_dict["attention_mask"].bool()
            inputs_dict["input_ids"][padding_positions] = config.get_text_config().pad_token_id

            # Ensure inputs_dict also has labels in it, as their presence/absence can induce
            # dtype conversions. This also lets us compare losses.
            labels = inputs_dict["input_ids"].clone()
            # Mask padding tokens
            labels[padding_positions] = -100
            # Also need to mask the first non-trivial token to match the padding-free batch.
            row_index = torch.arange(labels.size(0), device=labels.device)
            first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
            labels[row_index, first_nonneg_idx] = -100
            inputs_dict["labels"] = labels

            reloaded = model_class.from_pretrained(
                tmpdirname,
                dtype=torch.float16,
                attn_implementation="flash_attention_2",
            )
            model = reloaded.to(torch_device).eval()

            # flatten: keep only the non-padded token ids of every sequence
            features = [
                {"input_ids": token_ids[mask_row.bool()].tolist()}
                for token_ids, mask_row in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
            ]

            # add position_ids + fa_kwargs + seq_idx
            data_collator = DataCollatorWithFlattening(
                return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
            )
            batch = data_collator(features)
            batch_accelerator = {
                key: value.to(torch_device) if torch.is_tensor(value) else value for key, value in batch.items()
            }

            res_padded = model(**inputs_dict)
            res_padfree = model(**batch_accelerator)

            logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
            logits_padfree = res_padfree.logits[0]

            # Predicted tokens must match exactly ...
            torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
            # ... while raw logits may differ by acceptable numerical instability
            fp16_eps = torch.finfo(torch.float16).eps
            torch.testing.assert_close(logits_padded, logits_padfree, rtol=fp16_eps, atol=fp16_eps)

            torch.testing.assert_close(res_padded.loss, res_padfree.loss)
325+
99326 def _check_past_key_values_for_generate (self , batch_size , past_key_values , seq_length , config ):
100327 self .assertIsInstance (past_key_values , HybridMambaAttentionDynamicCache )
101328
@@ -178,6 +405,3 @@ def test_model_generation(self):
178405 text = tokenizer .decode (generated_ids [0 ], skip_special_tokens = True )
179406
180407 self .assertEqual (EXPECTED_TEXT_COMPLETION , text )
181-
182-
183- del BambaModelTest , BambaModelTester # So the parent tests don't run in this file too
0 commit comments