Skip to content

Commit 16924cd

Browse files
Stop inheriting tests (again) (#42247)
* Stop inheriting tests!
* Just use a del instead
* fixup
* Stop using del!
* make fixup
1 parent 266d3b0 commit 16924cd

File tree

2 files changed

+247
-11
lines changed

2 files changed

+247
-11
lines changed

tests/models/cohere2/test_modeling_cohere2.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@
3434
torch_device,
3535
)
3636

37-
from ...models.cohere.test_modeling_cohere import CohereModelTest, CohereModelTester
38-
from ...test_configuration_common import ConfigTester
37+
from ...models.cohere.test_modeling_cohere import CohereModelTester
3938

4039

4140
if is_torch_available():
@@ -46,6 +45,11 @@
4645
Cohere2Model,
4746
)
4847

48+
from ...generation.test_utils import GenerationTesterMixin
49+
from ...test_configuration_common import ConfigTester
50+
from ...test_modeling_common import ModelTesterMixin
51+
from ...test_pipeline_mixin import PipelineTesterMixin
52+
4953

5054
class Cohere2ModelTester(CohereModelTester):
5155
config_class = Cohere2Config
@@ -55,7 +59,7 @@ class Cohere2ModelTester(CohereModelTester):
5559

5660

5761
@require_torch
58-
class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
62+
class Cohere2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
5963
all_model_classes = (Cohere2Model, Cohere2ForCausalLM) if is_torch_available() else ()
6064
pipeline_model_mapping = (
6165
{
@@ -67,10 +71,21 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
6771
)
6872
_is_stateful = True
6973

74+
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
75+
# This is because we are hitting edge cases with the causal_mask buffer
76+
model_split_percents = [0.5, 0.7, 0.8]
77+
7078
def setUp(self):
    """Create the Cohere2 model tester and the config tester stored on this test case."""
    self.model_tester = Cohere2ModelTester(self)
    # Keyword arguments forwarded unchanged to ConfigTester.
    config_tester_kwargs = {"config_class": Cohere2Config, "hidden_size": 37}
    self.config_tester = ConfigTester(self, **config_tester_kwargs)
7381

82+
def test_config(self):
    """Run the config tester's common tests."""
    config_tester = self.config_tester
    config_tester.run_common_tests()
84+
85+
def test_model(self):
    """Prepare a config and inputs with the model tester and run its base-model check on them."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_model(*prepared)
88+
7489

7590
@slow
7691
@require_read_token
@@ -269,6 +284,3 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str):
269284
output_text = tokenizer.batch_decode(out)
270285

271286
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
272-
273-
274-
del CohereModelTest, CohereModelTester # So the parent tests don't run in this file too

tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py

Lines changed: 229 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,32 @@
1414
# limitations under the License.
1515
"""Testing suite for the PyTorch GraniteMoeHybrid model."""
1616

17+
import inspect
18+
import tempfile
1719
import unittest
1820

1921
import pytest
22+
from pytest import mark
2023

2124
from transformers import (
2225
AutoTokenizer,
26+
DataCollatorWithFlattening,
2327
GraniteMoeHybridConfig,
2428
is_torch_available,
2529
)
2630
from transformers.testing_utils import (
31+
require_flash_attn,
2732
require_torch,
2833
require_torch_gpu,
2934
slow,
3035
torch_device,
3136
)
3237

3338
from ...generation.test_utils import GenerationTesterMixin
34-
from ...models.bamba.test_modeling_bamba import BambaModelTest, BambaModelTester
39+
from ...models.bamba.test_modeling_bamba import BambaModelTester
40+
from ...test_configuration_common import ConfigTester
41+
from ...test_modeling_common import ModelTesterMixin
42+
from ...test_pipeline_mixin import PipelineTesterMixin
3543

3644

3745
if is_torch_available():
@@ -77,7 +85,7 @@ def get_config(self):
7785

7886

7987
@require_torch
80-
class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest.TestCase):
88+
class GraniteMoeHybridModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
8189
model_tester_class = GraniteMoeHybridModelTester
8290
all_model_classes = (
8391
(
@@ -96,6 +104,225 @@ class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest.
96104
else {}
97105
)
98106

107+
# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
108+
# This is because we are hitting edge cases with the causal_mask buffer
109+
model_split_percents = [0.5, 0.7, 0.8]
110+
111+
def _check_caches_are_equal(
    self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache
):
    """Assert that two hybrid caches hold matching states for every layer index.

    For each index the `key_cache`, `value_cache`, `conv_states` and `ssm_states`
    entries are compared with `torch.testing.assert_close`.

    Raises:
        ValueError: if either argument is not a `HybridMambaAttentionDynamicCache`,
            or if the two caches report different lengths.
    """
    both_hybrid = all(isinstance(cache, HybridMambaAttentionDynamicCache) for cache in (cache1, cache2))
    if not both_hybrid:
        raise ValueError("The wrong cache is being used!")

    if len(cache1) != len(cache2):
        raise ValueError("Both caches do not have the same number of layers.")

    # Compared in this order for every layer index.
    compared_fields = ("key_cache", "value_cache", "conv_states", "ssm_states")
    for layer_idx in range(len(cache1)):
        for field_name in compared_fields:
            torch.testing.assert_close(
                getattr(cache1, field_name)[layer_idx],
                getattr(cache2, field_name)[layer_idx],
            )
128+
129+
def setUp(self):
    """Instantiate `model_tester_class` and a config tester for the tester's config class."""
    tester = self.model_tester_class(self)
    self.model_tester = tester
    self.config_tester = ConfigTester(
        self,
        config_class=tester.config_class,
        hidden_size=64,
    )
132+
133+
def test_config(self):
    """Run the config tester's common tests."""
    config_tester = self.config_tester
    config_tester.run_common_tests()
135+
136+
def test_model(self):
    """Prepare a config and inputs with the model tester and run its base-model check on them."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_model(*prepared)
139+
140+
def test_for_causal_lm(self):
    """Prepare a config and inputs with the model tester and run its causal-LM check on them."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_causal_lm(*prepared)
143+
144+
def test_decoder_model_past_with_large_inputs(self):
    """Prepare a config and inputs and run the tester's decoder-past-with-large-inputs check on them."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_decoder_model_past_large_inputs(*prepared)
147+
148+
def test_attention_outputs(self):
    r"""
    Override of `test_attention_outputs` that checks the `attentions` output of every model class.

    The expected number of attention tensors is `num_hidden_layers - len(attn_layer_indices)`, both read from
    the model tester. Attentions are requested three ways: through the forward kwargs (eager attention),
    through the config, and together with hidden states.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    config.return_dict = True

    seq_len = getattr(self.model_tester, "seq_length", None)
    encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
    encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

    expected_num_attentions = self.model_tester.num_hidden_layers - len(self.model_tester.attn_layer_indices)

    def run_forward(candidate):
        # Move the model to the test device, switch to eval mode and run one forward pass without gradients.
        candidate.to(torch_device)
        candidate.eval()
        with torch.no_grad():
            return candidate(**self._prepare_for_class(inputs_dict, model_class))

    def assert_attention_shape(attention):
        # The trailing three dimensions must be (heads, query length, key length).
        self.assertListEqual(
            list(attention.shape[-3:]),
            [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
        )

    for model_class in self.all_model_classes:
        # Request attentions through the forward kwargs, with the eager attention implementation.
        inputs_dict["output_attentions"] = True
        inputs_dict["output_hidden_states"] = False
        config.return_dict = True
        eager_model = model_class._from_config(config, attn_implementation="eager")
        # Later models in this iteration (and the next one) are built from the config held by this model.
        config = eager_model.config
        outputs = run_forward(eager_model)
        self.assertEqual(len(outputs.attentions), expected_num_attentions)

        # check that output_attentions also work using config
        del inputs_dict["output_attentions"]
        config.output_attentions = True
        outputs = run_forward(model_class(config))
        attentions = outputs.attentions
        self.assertEqual(len(attentions), expected_num_attentions)
        assert_attention_shape(attentions[0])
        out_len = len(outputs)

        # Check attention is always last and order is fine
        inputs_dict["output_attentions"] = True
        inputs_dict["output_hidden_states"] = True
        outputs = run_forward(model_class(config))

        # Requesting hidden states adds exactly one entry to the outputs.
        added_hidden_states = 1
        self.assertEqual(out_len + added_hidden_states, len(outputs))

        self_attentions = outputs.attentions
        self.assertEqual(len(self_attentions), expected_num_attentions)
        assert_attention_shape(self_attentions[0])
211+
212+
def test_batching_equivalence(self):
    """Run the inherited batching-equivalence test with the tester's input mask disabled.

    `use_input_mask` is restored in a `finally` block, so a failure inside the inherited test does not
    leave the model tester with the mask switched off.
    """
    # need to disable the tril input mask
    orig = self.model_tester.use_input_mask
    self.model_tester.use_input_mask = False
    try:
        super().test_batching_equivalence()
    finally:
        self.model_tester.use_input_mask = orig
218+
219+
@pytest.mark.generate
def test_left_padding_compatibility(self):
    """Run the inherited left-padding test, passing `attention_mask=None` as the unpadded custom input."""
    # TODO: document why a random attention mask causes this test to fail, but a full mask doesn't
    super().test_left_padding_compatibility(unpadded_custom_inputs={"attention_mask": None})
224+
225+
@unittest.skip(
    reason="Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
    """Unconditionally skipped; see the skip reason on the decorator."""
230+
231+
@unittest.skip(
    reason="Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
    """Unconditionally skipped; see the skip reason on the decorator."""
236+
237+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@unittest.skip(
    "NotImplementedError: seq_idx support requires fast path support. Please install mamba_ssm and causal_conv1d"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_idx_and_fa_kwargs(self):
    """
    Compare a padded forward pass with a padding-free (flattened) one under Flash Attention 2.

    For every generative model class the model is saved to a temporary directory, reloaded in float16 with
    `attn_implementation="flash_attention_2"`, and run on (a) the padded inputs and (b) the same tokens flattened
    by `DataCollatorWithFlattening` with seq_idx and flash-attention kwargs. Logits at non-padded positions and
    the losses must match. The test is currently disabled by the `unittest.skip` decorator.
    """
    if not self.has_attentions:
        self.skipTest(reason="Model architecture does not support attentions")

    max_new_tokens = 30

    for model_class in self.all_generative_model_classes:
        if not model_class._supports_flash_attn:
            self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # The comparison is only meaningful when the dummy batch actually contains padding.
        if "attention_mask" not in inputs_dict or 0 not in inputs_dict["attention_mask"]:
            self.skipTest("Model dummy inputs should contain padding in their attention mask")

        dummy_input = inputs_dict[model_class.main_input_name]
        if dummy_input.dtype in (torch.float32, torch.bfloat16):
            dummy_input = dummy_input.to(torch.float16)

        # make sure that all models have enough positions for generation
        if hasattr(config, "max_position_embeddings"):
            config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

        model = model_class(config)
        forward_params = inspect.signature(model.forward).parameters
        if "position_ids" not in forward_params:
            self.skipTest("Model does not support position_ids")

        with tempfile.TemporaryDirectory() as checkpoint_dir:
            model.save_pretrained(checkpoint_dir)

            # ensure left padding, to adapt for some models
            if 0 in inputs_dict["attention_mask"][:, -1]:
                inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
            pad_positions = ~inputs_dict["attention_mask"].bool()
            inputs_dict["input_ids"][pad_positions] = config.get_text_config().pad_token_id

            # Ensure inputs_dict also has labels in it, as their presence/absence can induce
            # dtype conversions. This also lets us compare losses.
            labels = inputs_dict["input_ids"].clone()
            # Mask padding tokens
            labels[pad_positions] = -100
            # Also need to mask the first non-trivial token to match the padding-free batch.
            first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
            row_indices = torch.arange(labels.size(0), device=labels.device)
            labels[row_indices, first_nonneg_idx] = -100
            inputs_dict["labels"] = labels

            model = model_class.from_pretrained(
                checkpoint_dir,
                dtype=torch.float16,
                attn_implementation="flash_attention_2",
            )
            model = model.to(torch_device).eval()

            # flatten: keep only the non-padded token ids of each row
            features = [
                {"input_ids": row_ids[row_mask.bool()].tolist()}
                for row_ids, row_mask in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
            ]

            # add position_ids + fa_kwargs + seq_idx
            data_collator = DataCollatorWithFlattening(
                return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
            )
            batch = data_collator(features)
            batch_accelerator = {
                key: value.to(torch_device) if torch.is_tensor(value) else value for key, value in batch.items()
            }

            res_padded = model(**inputs_dict)
            res_padfree = model(**batch_accelerator)

            logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
            logits_padfree = res_padfree.logits[0]

            # Predicted tokens must agree exactly.
            torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
            # acceptable numerical instability
            tol = torch.finfo(torch.float16).eps
            torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)

            torch.testing.assert_close(res_padded.loss, res_padfree.loss)
325+
99326
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
100327
self.assertIsInstance(past_key_values, HybridMambaAttentionDynamicCache)
101328

@@ -178,6 +405,3 @@ def test_model_generation(self):
178405
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
179406

180407
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
181-
182-
183-
del BambaModelTest, BambaModelTester # So the parent tests don't run in this file too

0 commit comments

Comments
 (0)