 from typing import Optional, Union

 import torch
-import torch.nn.functional as F
 from torch import nn

 from ...activations import ACT2FN
@@ -164,47 +163,99 @@ class AfmoeTokenChoiceRouter(nn.Module):
     """
     Token-choice top-K router for MoE routing.

-    This router assigns each token to the top-K experts based on learned routing scores.
-    It supports both sigmoid and softmax scoring functions.
+    This router assigns each token to the top-K experts based on sigmoid scores, matching the released checkpoints.
     """

     def __init__(self, config):
         super().__init__()
         self.config = config
         self.top_k = config.num_experts_per_tok
         self.num_experts = config.num_experts
-        self.score_func = config.score_func
         self.route_norm = config.route_norm
         self.route_scale = config.route_scale
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

-    def forward(self, hidden_states, expert_bias: torch.Tensor | None):
+    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor | None = None):
+        # Keep the expert_bias argument for checkpoint/backwards compatibility (it is always zero in released models).
+        del expert_bias
         _, _, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

-        scores = self.gate(hidden_states)
+        scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))
+        top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)

-        # Apply scoring function in float32 for stability
-        if self.score_func == "sigmoid":
-            scores = torch.sigmoid(scores.to(torch.float32))
-        else:
-            scores = F.softmax(scores.to(torch.float32), dim=-1)
-
-        if expert_bias is not None:
-            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
-            top_scores = scores.gather(dim=1, index=selected_experts)
-        else:
-            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
-
-        # Normalize weights if using sigmoid
-        if self.score_func == "sigmoid" and self.route_norm:
+        if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
             top_scores = top_scores / denominator

         top_scores = top_scores * self.route_scale
         return top_scores, selected_experts

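For reference, the new routing path reduces to sigmoid gate scores computed in float32, top-k selection, optional sum-normalization, and a fixed scale. A standalone sketch with toy sizes (the hyperparameters below are illustrative stand-ins, not values from a released AFMoE config):

    import torch
    from torch import nn

    hidden_size, num_experts, top_k, route_scale = 16, 8, 2, 1.0  # hypothetical values
    gate = nn.Linear(hidden_size, num_experts, bias=False)

    tokens = torch.randn(5, hidden_size)                          # 5 tokens, already flattened to (batch * seq, hidden)
    scores = torch.sigmoid(gate(tokens).to(torch.float32))        # sigmoid scoring in float32
    top_scores, selected_experts = torch.topk(scores, k=top_k, dim=1)
    top_scores = top_scores / (top_scores.sum(dim=-1, keepdim=True) + 1e-20)  # route_norm=True branch
    top_scores = top_scores * route_scale
    print(top_scores.shape, selected_experts.shape)               # torch.Size([5, 2]) torch.Size([5, 2])
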
+class AfmoeExperts(nn.ModuleList):
+    """
+    Container holding the routed experts.
+
+    This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
+    """
+
+    _checkpoint_conversion_mapping = {"experts": "experts"}
+
+    def __init__(self, config: AfmoeConfig):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_experts
+        for _ in range(self.num_experts):
+            self.append(AfmoeMLP(config, intermediate_size=config.moe_intermediate_size))
+
+    def forward(
+        self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states: (batch, seq, hidden)
+            selected_experts: (batch, seq, top_k)
+            routing_weights: (batch, seq, top_k)
+        """
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+        if seq_len == 0:
+            return hidden_states.new_zeros(batch_size, 0, hidden_dim)
+        hidden_states_flat = hidden_states.view(-1, hidden_dim)
+        top_k = selected_experts.shape[-1]
+
+        # Map every token routing decision to a unique position so we can process expert by expert.
+        token_indices = torch.arange(
+            hidden_states_flat.shape[0], device=hidden_states.device, dtype=torch.long
+        ).repeat_interleave(top_k)
+        expert_indices = selected_experts.reshape(-1)
+        routing_weights = routing_weights.reshape(-1)
+
+        sorting = torch.argsort(expert_indices, stable=True)
+        token_indices = token_indices[sorting]
+        expert_indices = expert_indices[sorting]
+        routing_weights = routing_weights[sorting]
+
+        dispatched_tokens = hidden_states_flat.index_select(0, token_indices)
+        expert_outputs = torch.zeros_like(dispatched_tokens)
+
+        unique_experts, counts = torch.unique_consecutive(expert_indices, return_counts=True)
+        start = 0
+        for expert_id, count in zip(unique_experts.tolist(), counts.tolist()):
+            if count == 0:
+                continue
+            end = start + count
+            expert_input = dispatched_tokens[start:end]
+            expert_output = self[expert_id](expert_input)
+            expert_outputs[start:end] = expert_output
+            start = end
+
+        weighted_outputs = (expert_outputs.to(torch.float32) * routing_weights.unsqueeze(-1)).to(hidden_states.dtype)
+        aggregated = torch.zeros_like(hidden_states_flat)
+        scatter_indices = token_indices.unsqueeze(-1).expand_as(weighted_outputs)
+        aggregated.scatter_add_(0, scatter_indices, weighted_outputs)
+        return aggregated.view(batch_size, seq_len, hidden_dim)
+
+
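The sort-and-scatter dispatch above is equivalent to running each token through its selected experts one by one; a minimal check of that equivalence, using plain nn.Linear layers as stand-in experts rather than the real AfmoeMLP modules:

    import torch
    from torch import nn

    torch.manual_seed(0)
    batch, seq, hidden, num_experts, top_k = 2, 3, 4, 4, 2       # toy sizes
    experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))

    x = torch.randn(batch, seq, hidden)
    selected = torch.randint(num_experts, (batch, seq, top_k))   # routing decisions
    weights = torch.rand(batch, seq, top_k)                      # routing weights

    # Naive reference: run every token through each of its selected experts.
    ref = torch.zeros_like(x)
    for b in range(batch):
        for s in range(seq):
            for k in range(top_k):
                ref[b, s] += weights[b, s, k] * experts[int(selected[b, s, k])](x[b, s])

    # Grouped dispatch, mirroring AfmoeExperts.forward: sort the (token, expert) pairs by
    # expert id, run each expert once on a contiguous slice, then scatter_add back per token.
    flat = x.view(-1, hidden)
    tok = torch.arange(flat.shape[0]).repeat_interleave(top_k)
    eid, w = selected.reshape(-1), weights.reshape(-1)
    order = torch.argsort(eid, stable=True)
    tok, eid, w = tok[order], eid[order], w[order]
    inp = flat.index_select(0, tok)
    out = torch.zeros_like(inp)
    uniq, counts = torch.unique_consecutive(eid, return_counts=True)
    start = 0
    for e, c in zip(uniq.tolist(), counts.tolist()):
        out[start:start + c] = experts[e](inp[start:start + c])
        start += c
    agg = torch.zeros_like(flat)
    agg.scatter_add_(0, tok.unsqueeze(-1).expand_as(out), out * w.unsqueeze(-1))
    print(torch.allclose(agg.view(batch, seq, hidden), ref, atol=1e-5))  # True
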
 class AfmoeMoE(nn.Module):
     """
     Mixture of Experts (MoE) module for AFMoE.
@@ -221,9 +272,7 @@ def __init__(self, config):
         self.shared_experts = None
         if config.num_shared_experts > 0:
             self.shared_experts = AfmoeMLP(config, config.moe_intermediate_size * config.num_shared_experts)
-        self.experts = nn.ModuleList(
-            [AfmoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)]
-        )
+        self.experts = AfmoeExperts(config)
         self.expert_bias = nn.Parameter(torch.zeros(config.num_experts, dtype=torch.float32), requires_grad=False)

     def forward(self, hidden_states):
@@ -232,37 +281,17 @@ def forward(self, hidden_states):

         # Get routing decisions
         top_scores, selected_experts = self.router(hidden_states, self.expert_bias)
+        top_scores = top_scores.view(batch_size, seq_len, self.config.num_experts_per_tok)
+        selected_experts = selected_experts.view(batch_size, seq_len, self.config.num_experts_per_tok)

         # Process through shared experts
         if self.shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states_flat)
+            shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim)
         else:
-            shared_output = torch.zeros_like(hidden_states_flat)
-
-        # Reorder tokens by expert for efficient processing
-        token_indices_sorted = torch.argsort(selected_experts.view(-1), stable=True)
-        top_scores_sorted = top_scores.view(-1)[token_indices_sorted]
-        token_to_expert = selected_experts.view(-1)[token_indices_sorted]
-        token_indices_sorted = token_indices_sorted // self.config.num_experts_per_tok
+            shared_output = hidden_states.new_zeros(batch_size, seq_len, hidden_dim)

-        # Gather input tokens
-        token_indices_expanded = token_indices_sorted.unsqueeze(-1).expand(-1, hidden_dim)
-        routed_input = torch.gather(hidden_states_flat, dim=0, index=token_indices_expanded)
-
-        routed_output = torch.zeros_like(routed_input)
-        for expert_id in range(self.config.num_experts):
-            mask = token_to_expert == expert_id
-            if mask.any():
-                expert_input = routed_input[mask]
-                expert_out = self.experts[expert_id](expert_input)
-                routed_output[mask] = expert_out
-
-        routed_output = (routed_output.to(torch.float32) * top_scores_sorted.unsqueeze(-1)).to(hidden_states.dtype)
-
-        # Scatter back to original positions
-        output = shared_output.scatter_add(dim=0, index=token_indices_expanded, src=routed_output)
-
-        return output.view(batch_size, seq_len, hidden_dim)
+        routed_output = self.experts(hidden_states, selected_experts, top_scores)
+        return shared_output + routed_output


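With the dispatch moved into AfmoeExperts, the layer output is simply the shared branch plus the routed branch. For a single token the data flow is roughly the following (stand-in nn.Linear modules, hypothetical sizes and router output, not the module's exact API):

    import torch
    from torch import nn

    hidden = 8                                            # hypothetical size
    shared_expert = nn.Linear(hidden, hidden)             # stand-in for the shared AfmoeMLP
    routed_experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(4))

    x_t = torch.randn(hidden)                             # one token
    selected, weights = [1, 3], [0.6, 0.4]                # hypothetical router output for this token
    routed = sum(w * routed_experts[e](x_t) for e, w in zip(selected, weights))
    output_t = shared_expert(x_t) + routed                # shared branch + weighted routed branch
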
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
@@ -318,32 +347,40 @@ class AfmoeAttention(nn.Module):
     Multi-headed attention module with optional sliding window and gating.

     This attention mechanism supports both full attention and sliding window attention,
-    and includes Q/K normalization and gating of the output.
+    and includes Q/K normalization and gating of the output. It inherits from [`LlamaAttention`] to minimize the amount
+    of custom logic we need to maintain.
     """

     def __init__(self, config: AfmoeConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_heads = config.num_attention_heads
-        self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+        )
+        # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
+        # We only add AFMoE-specific attributes
         self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
         self.sliding_window = config.sliding_window if self.is_local_attention else None

-        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
-
         self.q_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
-        self.gate_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.gate_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)

     def forward(
         self,
@@ -362,11 +399,8 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape)
         gate_states = self.gate_proj(hidden_states)

-        query_states = self.q_norm(query_states)
-        key_states = self.k_norm(key_states)
-
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
+        query_states = self.q_norm(query_states).transpose(1, 2)
+        key_states = self.k_norm(key_states).transpose(1, 2)
         value_states = value_states.transpose(1, 2)

         if self.is_local_attention:
@@ -394,7 +428,7 @@ def forward(
         )

         output = output.view(*input_shape, -1).contiguous()
-        output = output * F.sigmoid(gate_states)
+        output = output * torch.sigmoid(gate_states)
         return self.o_proj(output)


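The switch from F.sigmoid to torch.sigmoid is behaviour-preserving (F is no longer imported); the gating itself multiplies the attention output elementwise by a sigmoid of the gate projection. A shape-only sketch with hypothetical sizes and plain random tensors in place of the real projections:

    import torch

    batch, seq, num_heads, head_dim = 1, 4, 8, 16                # hypothetical sizes
    attn_output = torch.randn(batch, seq, num_heads * head_dim)  # after .view(*input_shape, -1)
    gate_states = torch.randn(batch, seq, num_heads * head_dim)  # gate_proj(hidden_states)
    gated = attn_output * torch.sigmoid(gate_states)             # per-channel gate in (0, 1)
    print(gated.shape)                                           # torch.Size([1, 4, 128])
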
@@ -505,15 +539,15 @@ class AfmoePreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, AfmoeRMSNorm):
-            module.weight.data.fill_(1.0)
+            module.weight.fill_(1.0)


 @auto_docstring
@@ -591,11 +625,7 @@ def forward(
             "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
         }

-        hidden_states = inputs_embeds
-
-        # Apply muP input scaling if enabled
-        if self.config.mup_enabled:
-            hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states = inputs_embeds * (self.config.hidden_size**0.5)

         position_embeddings = self.rotary_emb(hidden_states, position_ids)

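The rewritten line hard-codes the muP-style input scaling: token embeddings are always multiplied by sqrt(hidden_size) rather than only when config.mup_enabled is set. Numerically, with a hypothetical hidden size:

    import torch

    hidden_size = 2048                                    # hypothetical value
    inputs_embeds = torch.randn(1, 4, hidden_size)
    hidden_states = inputs_embeds * (hidden_size**0.5)    # scale factor of about 45.25
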
@@ -633,7 +663,7 @@ class AfmoeForCausalLM(AfmoePreTrainedModel, GenerationMixin):
     [`~PreTrainedModel.from_pretrained`] method to load the model weights.
     """

-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

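The _tied_weights_keys change moves from a bare list of tied parameter names to an explicit mapping from the tied parameter to the parameter it is tied to. In plain PyTorch the tie amounts to the two modules sharing one weight tensor (minimal sketch, hypothetical sizes):

    from torch import nn

    vocab_size, hidden_size = 1000, 64                    # hypothetical sizes
    embed_tokens = nn.Embedding(vocab_size, hidden_size)
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
    lm_head.weight = embed_tokens.weight                  # "lm_head.weight" tied to "model.embed_tokens.weight"
    assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()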