
Commit 694410d

vasqu authored and LysandreJik committed
[Jetmoe] Fix RoPE (#40819)
* fix
* remove prints
* why was this there...
1 parent 240ebfe commit 694410d
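
The fix is a single config line: JetMoe's per-head attention width is `kv_channels` rather than the usual `hidden_size // num_attention_heads`, so RoPE code that reads `config.head_dim` must resolve to `kv_channels`. A minimal sketch of the mismatch being avoided, with illustrative numbers (assumptions for demonstration, not verified JetMoe config values):

# Illustrative sketch of the bug class this commit addresses; the numbers
# below are assumptions, not JetMoe's verified config values.
hidden_size = 2048
num_attention_heads = 32
kv_channels = 128  # the per-head attention dimension the model actually uses

# A generic fallback like this would size the rotary tables incorrectly:
fallback_head_dim = hidden_size // num_attention_heads
print(fallback_head_dim)  # 64 -- not the dimension attention runs at
print(kv_channels)        # 128 -- what RoPE should be built from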

File tree

3 files changed: +1 -11 lines changed

src/transformers/models/jetmoe/configuration_jetmoe.py

Lines changed: 1 addition & 0 deletions

@@ -94,6 +94,7 @@ class JetMoeConfig(PretrainedConfig):
 
     model_type = "jetmoe"
     keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {"head_dim": "kv_channels"}
 
     def __init__(
         self,
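
For context, `attribute_map` on `PretrainedConfig` redirects attribute access from one name to another, so `config.head_dim` transparently reads `config.kv_channels`. A minimal standalone sketch of that mechanism (simplified; the real logic lives in `PretrainedConfig.__getattribute__` in transformers):

# Standalone sketch of the attribute_map redirection idea; ConfigSketch is
# a hypothetical stand-in, not the real JetMoeConfig.
class ConfigSketch:
    # Maps the generic name RoPE code reads to the model-specific field.
    attribute_map = {"head_dim": "kv_channels"}

    def __init__(self, kv_channels=128):
        self.kv_channels = kv_channels

    def __getattribute__(self, key):
        # Redirect mapped names before the normal attribute lookup,
        # mirroring what PretrainedConfig does in transformers.
        if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
            key = super().__getattribute__("attribute_map")[key]
        return super().__getattribute__(key)


config = ConfigSketch(kv_channels=128)
print(config.head_dim)  # 128 -- resolved through attribute_map to kv_channels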

src/transformers/models/jetmoe/modeling_jetmoe.py

Lines changed: 0 additions & 9 deletions

@@ -936,15 +936,6 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
-            batch_size = inputs_embeds.shape[0]
-            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
-            if is_padding_right:
-                raise ValueError(
-                    "You are attempting to perform batched generation with padding_side='right'"
-                    " this may lead to unexpected behaviour for Flash Attention version of JetMoe. Make sure to "
-                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
-                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
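
The deleted guard raised on right padding when `flash_attention_2` was active; as the remaining context lines show, masking now runs through `_update_causal_mask` for every backend. Left padding is still the usual recipe for batched generation with decoder-only models; a usage sketch under that assumption (model id and prompts are only examples):

# Hedged usage sketch: batched generation with a decoder-only model,
# using left padding so the last position of every row is a real token.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b")
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("jetmoe/jetmoe-8b")

inputs = tokenizer(["Hello", "The capital of France is"], return_tensors="pt", padding=True)
generated = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))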

tests/models/jetmoe/test_modeling_jetmoe.py

Lines changed: 0 additions & 2 deletions

@@ -184,12 +184,10 @@ def test_model_8b_batched_generation(self):
         tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
         model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
         input_ids = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device)
-        print(input_ids)
 
         # greedy generation outputs
         generated_ids = model.generate(**input_ids, max_new_tokens=10, temperature=0)
         text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
-        print(text)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
 
         del model
