[FEAT] Add Grouped Topk Routing to LLaMAMoE (Based on DeepseekV3MoE) #2134
Open · ysjprojects wants to merge 5 commits into Lightning-AI:main from ysjprojects:deepseekv3moe
Commits (5) · Changes from all commits

- 936af4a — deepseekv3moe (ysjprojects)
- a39c070 — [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 67fd5eb — Update tests/test_deepseek_moe.py (ysjprojects)
- 76bf888 — Merge branch 'main' into deepseekv3moe (ysjprojects)
- 0d3dac3 — Merge branch 'main' into deepseekv3moe (ysjprojects)
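This PR ships without a written description, so a brief orientation: the diff teaches litgpt's `LLaMAMoE` the grouped top-k routing used by DeepSeek-V3, exposed through new config fields such as `n_expert_groups`, `n_topk_groups`, `n_topk_scores_per_group`, `routed_scaling_factor`, `norm_topk_prob`, and `first_k_dense_replace`. In grouped routing, the experts are partitioned into groups; each token first picks the most promising groups (scored by the sum of each group's two best expert scores), then runs an ordinary top-k over the experts inside the surviving groups. The sketch below illustrates that recipe as published for DeepSeek-V3 — the function name and signature are illustrative, not the PR's actual code — and it omits the learned `e_score_correction_bias` that DeepSeek-V3 adds to scores during selection (the test below does copy that bias between implementations).

```python
import torch


def grouped_topk_sketch(
    scores: torch.Tensor,          # (n_tokens, n_expert) router scores, e.g. sigmoid of gate logits
    n_expert_groups: int,          # experts are partitioned into this many equal-sized groups
    n_topk_groups: int,            # only this many groups survive per token
    n_expert_per_token: int,       # experts finally selected per token
    routed_scaling_factor: float,  # DeepSeek-V3 multiplies the routing weights by this
    norm_topk_prob: bool,          # renormalize the selected weights to sum to 1 first
):
    n_tokens, n_expert = scores.shape
    group_size = n_expert // n_expert_groups

    # 1) Score each group by the sum of its top-2 expert scores
    #    (the "2" is what the PR exposes as n_topk_scores_per_group; DeepSeek hardcodes it).
    group_scores = (
        scores.view(n_tokens, n_expert_groups, group_size).topk(2, dim=-1).values.sum(dim=-1)
    )

    # 2) Keep the n_topk_groups best groups and mask out every expert outside them.
    keep = group_scores.topk(n_topk_groups, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(-1, keep, 1.0)
    expert_mask = group_mask.repeat_interleave(group_size, dim=-1)
    masked = scores.masked_fill(expert_mask == 0, float("-inf"))

    # 3) Plain top-k over the surviving experts, then optional renorm plus scaling.
    weights, idx = masked.topk(n_expert_per_token, dim=-1)
    if norm_topk_prob:
        weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights * routed_scaling_factor, idx
```

Note that setting `n_expert_groups=1` and `n_topk_groups=1` makes the group stage a no-op, so this degenerates to ordinary top-k routing.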
tests/test_deepseek_moe.py — new file (+114 lines):

```python
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import pytest
import torch
from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM

from litgpt import Config
from litgpt.model import GPT, LLaMAMLP


@torch.inference_mode()
@pytest.mark.parametrize("batch_size", (1, 2))
@pytest.mark.parametrize("seq_len", (8, 16))
@pytest.mark.parametrize("device", [torch.device("cpu")])
def test_deepseek_moe_litgpt_vs_hf(batch_size, seq_len, device):
    """Compare the litgpt MoE block against the Hugging Face reference implementation."""
    config_litgpt = Config(
        padded_vocab_size=10000,
        n_layer=2,
        vocab_size=10000,
        n_embd=64,
        n_head=4,
        n_query_groups=4,
        head_size=16,
        norm_eps=1e-6,
        bias=False,
        latent_attention={
            "q_lora_rank": 32,
            "kv_lora_rank": 16,
            "qk_rope_head_dim": 8,
            "qk_nope_head_dim": 8,
            "v_head_dim": 16,
        },
        n_expert=16,
        n_shared_expert=1,
        n_expert_per_token=2,
        n_expert_groups=4,
        n_topk_groups=2,
        n_topk_scores_per_group=2,  # Note: Deepseek hardcodes this to `2`
        first_k_dense_replace=1,
        routed_scaling_factor=2.5,
        norm_topk_prob=True,
        moe_intermediate_size=20,
        mlp_class_name="LLaMAMoE",
    )

    config_hf = DeepseekV3Config(
        padded_vocab_size=10000,
        num_hidden_layers=2,
        vocab_size=10000,
        hidden_size=64,
        num_attention_heads=4,
        num_key_value_heads=4,
        q_lora_rank=32,
        kv_lora_rank=16,
        qk_rope_head_dim=8,
        qk_nope_head_dim=8,
        v_head_dim=16,
        rope_interleave=False,
        first_k_dense_replace=1,
        routed_scaling_factor=2.5,
        norm_topk_prob=True,
        n_routed_experts=config_litgpt.n_expert,
        n_shared_experts=config_litgpt.n_shared_expert,
        num_experts_per_tok=config_litgpt.n_expert_per_token,
        n_group=config_litgpt.n_expert_groups,
        topk_group=config_litgpt.n_topk_groups,
        moe_intermediate_size=config_litgpt.moe_intermediate_size,
    )

    model_litgpt = GPT(config_litgpt).to(device)
    model_litgpt.apply(model_litgpt._init_weights)

    mlp_litgpt = model_litgpt.transformer.h[0].mlp
    assert isinstance(mlp_litgpt, LLaMAMLP)  # Test first_k_dense_replace (k=1)

    moe_litgpt = model_litgpt.transformer.h[1].mlp
    model_hf = DeepseekV3ForCausalLM(config_hf).to(device)
    moe_hf = model_hf.model.layers[1].mlp

    moe_litgpt.eval()
    moe_hf.eval()

    sync_weights(moe_litgpt, moe_hf)

    hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device)

    output_litgpt = moe_litgpt(hidden_states)
    output_hf = moe_hf(hidden_states)

    assert torch.allclose(output_litgpt, output_hf, atol=1e-5)


def sync_weights(litgpt_model, hf_model):
    """Copy litgpt MoE weights into the HF model so the outputs are comparable."""
    print("Synchronizing MoE weights...")

    with torch.no_grad():
        # Router: weight matrix plus DeepSeek-V3's expert-score correction bias.
        if hasattr(litgpt_model, "gate"):
            if hasattr(litgpt_model.gate, "weight"):
                hf_model.gate.weight.copy_(litgpt_model.gate.weight)
            if hasattr(litgpt_model.gate, "e_score_correction_bias"):
                hf_model.gate.e_score_correction_bias.copy_(litgpt_model.gate.e_score_correction_bias)

        # litgpt's LLaMAMLP experts use fc_1/fc_2/proj; HF uses gate_proj/up_proj/down_proj.
        for litgpt_expert, hf_expert in zip(litgpt_model.experts, hf_model.experts):
            hf_expert.gate_proj.weight.copy_(litgpt_expert.fc_1.weight)
            hf_expert.up_proj.weight.copy_(litgpt_expert.fc_2.weight)
            hf_expert.down_proj.weight.copy_(litgpt_expert.proj.weight)

        if hasattr(litgpt_model, "shared_experts") and hasattr(hf_model, "shared_experts"):
            hf_model.shared_experts.gate_proj.weight.copy_(litgpt_model.shared_experts.fc_1.weight)
            hf_model.shared_experts.up_proj.weight.copy_(litgpt_model.shared_experts.fc_2.weight)
            hf_model.shared_experts.down_proj.weight.copy_(litgpt_model.shared_experts.proj.weight)

    print("MoE weight synchronization complete.")
```