Commit bcd7b97

Add flex attention support to AFMoE model
1 parent 1314162 commit bcd7b97

3 files changed: 63 additions, 1 deletion

docs/source/en/model_doc/afmoe.md

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
 rendered properly in your Markdown viewer.
 
 -->
-*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-14.*
+*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-18.*
 
 <div style="float: right;">
 <div class="flex flex-wrap space-x-1">

src/transformers/models/afmoe/modular_afmoe.py

Lines changed: 1 addition & 0 deletions
@@ -411,6 +411,7 @@ class AfmoePreTrainedModel(LlamaPreTrainedModel):
     ]
     _supports_sdpa = True
     _supports_flash_attn_2 = True
+    _supports_flex_attn = True
     _supports_attention_backend = True
     supports_gradient_checkpointing = True
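With _supports_flex_attn set, the flex attention backend can be requested at load time. A minimal sketch, assuming a transformers build that includes this commit and a PyTorch release with flex attention support (2.5 or newer):

    import torch
    from transformers import AutoModelForCausalLM

    # Request the flex attention backend enabled by this commit; the checkpoint
    # name reuses the one from the test script below and is only an example.
    model = AutoModelForCausalLM.from_pretrained(
        "arcee-train/afmoe-nano-sft-v3-pocketRL-v0.1.4-2",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flex_attention",
    )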

test_afmoe_load.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
#!/usr/bin/env python3
"""Quick test script to load AFMoE checkpoint weights."""

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Path to your checkpoint
CHECKPOINT_PATH = "arcee-train/afmoe-nano-sft-v3-pocketRL-v0.1.4-2"  # HuggingFace Hub checkpoint


def main():
    print("Loading AFMoE checkpoint...")

    # Load config
    config = AutoConfig.from_pretrained(CHECKPOINT_PATH, trust_remote_code=False)
    print(f"Config loaded: {config.model_type}")
    print(f" - Hidden size: {config.hidden_size}")
    print(f" - Num layers: {config.num_hidden_layers}")
    print(f" - Num experts: {config.num_experts}")
    print(f" - Num shared experts: {config.num_shared_experts}")
    print(f" - Top-k: {config.num_experts_per_tok}")

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT_PATH,
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=False,
    )
    print("\nModel loaded successfully!")
    print(f" - Model class: {model.__class__.__name__}")
    print(f" - Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")

    # Test forward pass
    print("\nTesting forward pass...")
    input_ids = torch.randint(0, config.vocab_size, (1, 10)).to(model.device)

    with torch.no_grad():
        outputs = model(input_ids)
    print(f" - Output logits shape: {outputs.logits.shape}")
    print(f" - Output dtype: {outputs.logits.dtype}")

    # Test generation (optional)
    try:
        tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
        prompt = "Hello, how are you?"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        print(f"\nTesting generation with prompt: '{prompt}'")
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated: {generated}")
    except Exception as e:
        print(f"\nSkipping tokenizer test: {e}")

    print("\n✅ All checks passed!")


if __name__ == "__main__":
    main()
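Note that this script loads the model with the default attention backend; to exercise the newly enabled flex attention path, the from_pretrained call could additionally pass attn_implementation="flex_attention", as in the sketch following the modular_afmoe.py diff.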
