#!/usr/bin/env python3
"""Quick test script to load AFMoE checkpoint weights."""

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Path to your checkpoint
CHECKPOINT_PATH = "arcee-train/afmoe-nano-sft-v3-pocketRL-v0.1.4-2"  # Hugging Face Hub checkpoint

def main():
    print("Loading AFMoE checkpoint...")

    # Load config
    config = AutoConfig.from_pretrained(CHECKPOINT_PATH, trust_remote_code=False)
    print(f"Config loaded: {config.model_type}")
    print(f" - Hidden size: {config.hidden_size}")
    print(f" - Num layers: {config.num_hidden_layers}")
    print(f" - Num experts: {config.num_experts}")
    print(f" - Num shared experts: {config.num_shared_experts}")
    print(f" - Top-k: {config.num_experts_per_tok}")

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT_PATH,
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=False,
    )
    print("\nModel loaded successfully!")
    print(f" - Model class: {model.__class__.__name__}")
    print(f" - Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B")

    # Test forward pass (random token IDs are fine for a shape/dtype smoke test)
    print("\nTesting forward pass...")
    input_ids = torch.randint(0, config.vocab_size, (1, 10)).to(model.device)

    with torch.no_grad():
        outputs = model(input_ids)
    print(f" - Output logits shape: {outputs.logits.shape}")
    print(f" - Output dtype: {outputs.logits.dtype}")

    # Test generation (optional)
    try:
        tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH)
        prompt = "Hello, how are you?"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        print(f"\nTesting generation with prompt: '{prompt}'")
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated: {generated}")
    except Exception as e:
        print(f"\nSkipping tokenizer test: {e}")

    print("\n✅ All checks passed!")


if __name__ == "__main__":
    main()