Commit 1314162

Address review feedback for AFMoE implementation
1 parent 3a4280c commit 1314162

File tree: 5 files changed (+207, -182 lines)


docs/source/en/model_doc/afmoe.md

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@ specific language governing permissions and limitations under the License.
 rendered properly in your Markdown viewer.
 
 -->
+*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-14.*
 
 <div style="float: right;">
 <div class="flex flex-wrap space-x-1">

src/transformers/models/afmoe/configuration_afmoe.py

Lines changed: 2 additions & 11 deletions

@@ -88,19 +88,15 @@ class AfmoeConfig(PreTrainedConfig):
             Number of experts to route each token to. This is the top-k value for the token-choice routing.
         num_shared_experts (`int`, *optional*, defaults to 2):
             Number of shared experts that are always activated for all tokens.
-        score_func (`str`, *optional*, defaults to `"sigmoid"`):
-            The scoring function for routing decisions. Can be either "sigmoid" or "softmax".
         route_norm (`bool`, *optional*, defaults to `True`):
-            Whether to normalize routing weights when using sigmoid scoring.
+            Whether to normalize routing weights.
         route_scale (`float`, *optional*, defaults to 1.0):
             Scaling factor applied to routing weights.
         global_attn_every_n_layers (`int`, *optional*, defaults to 4):
             The frequency of full attention layers. Every Nth layer will use full attention, while others use sliding
             window attention.
         sliding_window (`int`, *optional*, defaults to 1024):
             Sliding window size for local attention layers.
-        mup_enabled (`bool`, *optional*, defaults to `False`):
-            Whether to enable muP (Maximal Update Parametrization) scaling for training stability.
         layer_types (`list[str]`, *optional*):
             A list that explicitly maps each layer index with its attention type. Each element should be either
             "sliding_attention" or "full_attention". If not provided, it will be automatically generated based on

@@ -155,12 +151,10 @@ def __init__(
         num_experts: Optional[int] = 64,
         num_experts_per_tok: Optional[int] = 6,
         num_shared_experts: Optional[int] = 2,
-        score_func: Optional[str] = "sigmoid",
         route_norm: Optional[bool] = True,
         route_scale: Optional[float] = 1.0,
         global_attn_every_n_layers: Optional[int] = 4,
         sliding_window: Optional[int] = 1024,
-        mup_enabled: Optional[bool] = False,
         layer_types: Optional[list] = None,
         attention_dropout: Optional[float] = 0.0,
         **kwargs,

@@ -185,9 +179,9 @@ def __init__(
         self.num_experts_per_tok = num_experts_per_tok
         self.num_experts = num_experts
         self.num_shared_experts = num_shared_experts
-        self.score_func = score_func
         self.route_norm = route_norm
         self.route_scale = route_scale
+        self.attention_bias = False
 
         # Attention specific
         self.attention_dropout = attention_dropout

@@ -201,9 +195,6 @@ def __init__(
         ]
         layer_type_validation(self.layer_types)
 
-        # muP specific
-        self.mup_enabled = mup_enabled
-
         if num_key_value_heads is None:
             num_key_value_heads = num_attention_heads
 
src/transformers/models/afmoe/modeling_afmoe.py

Lines changed: 108 additions & 78 deletions

@@ -24,7 +24,6 @@
 from typing import Optional, Union
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 
 from ...activations import ACT2FN

@@ -164,47 +163,99 @@ class AfmoeTokenChoiceRouter(nn.Module):
     """
     Token-choice top-K router for MoE routing.
 
-    This router assigns each token to the top-K experts based on learned routing scores.
-    It supports both sigmoid and softmax scoring functions.
+    This router assigns each token to the top-K experts based on sigmoid scores, matching the released checkpoints.
     """
 
     def __init__(self, config):
         super().__init__()
         self.config = config
         self.top_k = config.num_experts_per_tok
         self.num_experts = config.num_experts
-        self.score_func = config.score_func
         self.route_norm = config.route_norm
         self.route_scale = config.route_scale
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
 
-    def forward(self, hidden_states, expert_bias: torch.Tensor | None):
+    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor | None = None):
+        # Keep expert_bias argument for checkpoint/backwards compatibility (it is always zero in released models).
+        del expert_bias
         _, _, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
-        scores = self.gate(hidden_states)
+        scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))
+        top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
 
-        # Apply scoring function in float32 for stability
-        if self.score_func == "sigmoid":
-            scores = torch.sigmoid(scores.to(torch.float32))
-        else:
-            scores = F.softmax(scores.to(torch.float32), dim=-1)
-
-        if expert_bias is not None:
-            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
-            top_scores = scores.gather(dim=1, index=selected_experts)
-        else:
-            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
-
-        # Normalize weights if using sigmoid
-        if self.score_func == "sigmoid" and self.route_norm:
+        if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
             top_scores = top_scores / denominator
 
         top_scores = top_scores * self.route_scale
         return top_scores, selected_experts
 
 
+class AfmoeExperts(nn.ModuleList):
+    """
+    Container holding the routed experts.
+
+    This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
+    """
+
+    _checkpoint_conversion_mapping = {"experts": "experts"}
+
+    def __init__(self, config: AfmoeConfig):
+        super().__init__()
+        self.top_k = config.num_experts_per_tok
+        self.num_experts = config.num_experts
+        for _ in range(self.num_experts):
+            self.append(AfmoeMLP(config, intermediate_size=config.moe_intermediate_size))
+
+    def forward(
+        self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states: (batch, seq, hidden)
+            selected_experts: (batch, seq, top_k)
+            routing_weights: (batch, seq, top_k)
+        """
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+        if seq_len == 0:
+            return hidden_states.new_zeros(batch_size, 0, hidden_dim)
+        hidden_states_flat = hidden_states.view(-1, hidden_dim)
+        top_k = selected_experts.shape[-1]
+
+        # Map every token routing decision to a unique position so we can process expert by expert.
+        token_indices = torch.arange(
+            hidden_states_flat.shape[0], device=hidden_states.device, dtype=torch.long
+        ).repeat_interleave(top_k)
+        expert_indices = selected_experts.reshape(-1)
+        routing_weights = routing_weights.reshape(-1)
+
+        sorting = torch.argsort(expert_indices, stable=True)
+        token_indices = token_indices[sorting]
+        expert_indices = expert_indices[sorting]
+        routing_weights = routing_weights[sorting]
+
+        dispatched_tokens = hidden_states_flat.index_select(0, token_indices)
+        expert_outputs = torch.zeros_like(dispatched_tokens)
+
+        unique_experts, counts = torch.unique_consecutive(expert_indices, return_counts=True)
+        start = 0
+        for expert_id, count in zip(unique_experts.tolist(), counts.tolist()):
+            if count == 0:
+                continue
+            end = start + count
+            expert_input = dispatched_tokens[start:end]
+            expert_output = self[expert_id](expert_input)
+            expert_outputs[start:end] = expert_output
+            start = end
+
+        weighted_outputs = (expert_outputs.to(torch.float32) * routing_weights.unsqueeze(-1)).to(hidden_states.dtype)
+        aggregated = torch.zeros_like(hidden_states_flat)
+        scatter_indices = token_indices.unsqueeze(-1).expand_as(weighted_outputs)
+        aggregated.scatter_add_(0, scatter_indices, weighted_outputs)
+        return aggregated.view(batch_size, seq_len, hidden_dim)
+
+
 class AfmoeMoE(nn.Module):
     """
     Mixture of Experts (MoE) module for AFMoE.
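A standalone sketch of the sigmoid top-k routing math introduced above. It mirrors the diff but is a hedged re-implementation for illustration, not the module itself; the toy sizes are arbitrary.

```python
import torch


def sigmoid_topk_route(logits: torch.Tensor, top_k: int, route_norm: bool = True, route_scale: float = 1.0):
    """Sketch of the routing math: logits is (num_tokens, num_experts) from the gating linear layer."""
    scores = torch.sigmoid(logits.to(torch.float32))            # score in float32 for stability
    top_scores, selected_experts = torch.topk(scores, k=top_k, dim=1)
    if route_norm:
        top_scores = top_scores / (top_scores.sum(dim=-1, keepdim=True) + 1e-20)
    return top_scores * route_scale, selected_experts


logits = torch.randn(5, 8)          # 5 tokens, 8 experts (toy sizes)
weights, experts = sigmoid_topk_route(logits, top_k=2)
print(weights.sum(dim=-1))          # ~1.0 per token when route_norm=True and route_scale=1.0
```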

@@ -221,9 +272,7 @@ def __init__(self, config):
         self.shared_experts = None
         if config.num_shared_experts > 0:
             self.shared_experts = AfmoeMLP(config, config.moe_intermediate_size * config.num_shared_experts)
-        self.experts = nn.ModuleList(
-            [AfmoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.num_experts)]
-        )
+        self.experts = AfmoeExperts(config)
         self.expert_bias = nn.Parameter(torch.zeros(config.num_experts, dtype=torch.float32), requires_grad=False)
 
     def forward(self, hidden_states):

@@ -232,37 +281,17 @@ def forward(self, hidden_states):
 
         # Get routing decisions
         top_scores, selected_experts = self.router(hidden_states, self.expert_bias)
+        top_scores = top_scores.view(batch_size, seq_len, self.config.num_experts_per_tok)
+        selected_experts = selected_experts.view(batch_size, seq_len, self.config.num_experts_per_tok)
 
         # Process through shared experts
         if self.shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states_flat)
+            shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim)
         else:
-            shared_output = torch.zeros_like(hidden_states_flat)
-
-        # Reorder tokens by expert for efficient processing
-        token_indices_sorted = torch.argsort(selected_experts.view(-1), stable=True)
-        top_scores_sorted = top_scores.view(-1)[token_indices_sorted]
-        token_to_expert = selected_experts.view(-1)[token_indices_sorted]
-        token_indices_sorted = token_indices_sorted // self.config.num_experts_per_tok
+            shared_output = hidden_states.new_zeros(batch_size, seq_len, hidden_dim)
 
-        # Gather input tokens
-        token_indices_expanded = token_indices_sorted.unsqueeze(-1).expand(-1, hidden_dim)
-        routed_input = torch.gather(hidden_states_flat, dim=0, index=token_indices_expanded)
-
-        routed_output = torch.zeros_like(routed_input)
-        for expert_id in range(self.config.num_experts):
-            mask = token_to_expert == expert_id
-            if mask.any():
-                expert_input = routed_input[mask]
-                expert_out = self.experts[expert_id](expert_input)
-                routed_output[mask] = expert_out
-
-        routed_output = (routed_output.to(torch.float32) * top_scores_sorted.unsqueeze(-1)).to(hidden_states.dtype)
-
-        # Scatter back to original positions
-        output = shared_output.scatter_add(dim=0, index=token_indices_expanded, src=routed_output)
-
-        return output.view(batch_size, seq_len, hidden_dim)
+        routed_output = self.experts(hidden_states, selected_experts, top_scores)
+        return shared_output + routed_output
 
 
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
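To see why the grouped dispatch in `AfmoeExperts` matches the per-expert masking loop it replaces, here is a hedged toy equivalence check. The experts are plain `nn.Linear` stand-ins rather than `AfmoeMLP`, and all names and sizes are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_tokens, hidden, num_experts, top_k = 6, 4, 3, 2
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))

x = torch.randn(num_tokens, hidden)
selected = torch.randint(0, num_experts, (num_tokens, top_k))
weights = torch.rand(num_tokens, top_k)

# Grouped dispatch: sort routing decisions by expert, run each expert on a
# contiguous slice, then scatter-add the weighted outputs back per token.
token_idx = torch.arange(num_tokens).repeat_interleave(top_k)
expert_idx = selected.reshape(-1)
w = weights.reshape(-1)
order = torch.argsort(expert_idx, stable=True)
token_idx, expert_idx, w = token_idx[order], expert_idx[order], w[order]
dispatched = x.index_select(0, token_idx)
out = torch.zeros_like(dispatched)
uniq, counts = torch.unique_consecutive(expert_idx, return_counts=True)
start = 0
for eid, cnt in zip(uniq.tolist(), counts.tolist()):
    out[start : start + cnt] = experts[eid](dispatched[start : start + cnt])
    start += cnt
grouped = torch.zeros_like(x).scatter_add_(0, token_idx.unsqueeze(-1).expand_as(out), out * w.unsqueeze(-1))

# Naive reference: loop over every token and each of its top-k experts.
naive = torch.zeros_like(x)
for t in range(num_tokens):
    for k in range(top_k):
        naive[t] += weights[t, k] * experts[int(selected[t, k])](x[t])

print(torch.allclose(grouped, naive, atol=1e-5))  # True
```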

@@ -318,32 +347,40 @@ class AfmoeAttention(nn.Module):
     Multi-headed attention module with optional sliding window and gating.
 
     This attention mechanism supports both full attention and sliding window attention,
-    and includes Q/K normalization and gating of the output.
+    and includes Q/K normalization and gating of the output. It inherits from [`LlamaAttention`] to minimize the amount
+    of custom logic we need to maintain.
     """
 
     def __init__(self, config: AfmoeConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_heads = config.num_attention_heads
-        self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-
+        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+        )
+        # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
+        # We only add AFMoE-specific attributes
         self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
         self.sliding_window = config.sliding_window if self.is_local_attention else None
 
-        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
-
         self.q_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = AfmoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
-        self.gate_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.gate_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
 
     def forward(
         self,
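The refactor keeps per-head Q/K normalization over `head_dim` before the head transpose. A hedged shape sketch of that ordering, using `nn.RMSNorm` (PyTorch 2.4+) as a stand-in for `AfmoeRMSNorm`; the sizes are toy values.

```python
import torch
from torch import nn

batch, seq, num_heads, head_dim = 2, 5, 4, 8
hidden = num_heads * head_dim

q_proj = nn.Linear(hidden, num_heads * head_dim, bias=False)
q_norm = nn.RMSNorm(head_dim, eps=1e-6)   # stand-in for AfmoeRMSNorm(head_dim)

hidden_states = torch.randn(batch, seq, hidden)
# Project, split into heads, normalize each head over head_dim, then move heads forward.
query_states = q_proj(hidden_states).view(batch, seq, num_heads, head_dim)
query_states = q_norm(query_states).transpose(1, 2)   # (batch, num_heads, seq, head_dim)
print(query_states.shape)  # torch.Size([2, 4, 5, 8])
```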

@@ -362,11 +399,8 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape)
         gate_states = self.gate_proj(hidden_states)
 
-        query_states = self.q_norm(query_states)
-        key_states = self.k_norm(key_states)
-
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
+        query_states = self.q_norm(query_states).transpose(1, 2)
+        key_states = self.k_norm(key_states).transpose(1, 2)
         value_states = value_states.transpose(1, 2)
 
         if self.is_local_attention:

@@ -394,7 +428,7 @@ def forward(
         )
 
         output = output.view(*input_shape, -1).contiguous()
-        output = output * F.sigmoid(gate_states)
+        output = output * torch.sigmoid(gate_states)
         return self.o_proj(output)
 
 
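The functional content of this hunk is only the `F.sigmoid` to `torch.sigmoid` swap; the output gating itself is unchanged. A minimal sketch of that gating with toy shapes:

```python
import torch

attn_output = torch.randn(2, 5, 32)     # (batch, seq, num_heads * head_dim), toy sizes
gate_states = torch.randn(2, 5, 32)     # output of gate_proj(hidden_states)

# Elementwise sigmoid gate on the attention output before o_proj, as in the diff.
gated = attn_output * torch.sigmoid(gate_states)
print(gated.shape)  # torch.Size([2, 5, 32])
```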

@@ -505,15 +539,15 @@ class AfmoePreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, AfmoeRMSNorm):
-            module.weight.data.fill_(1.0)
+            module.weight.fill_(1.0)
 
 
 @auto_docstring
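Dropping the `.data` indirection assumes the in-place initializers run without gradient tracking, which is how Transformers invokes `_init_weights`. A hedged standalone sketch of the idiom, with an explicit `no_grad` guard for clarity:

```python
import torch
from torch import nn

initializer_range = 0.02
module = nn.Linear(16, 16)

# In-place init on the parameters themselves, no `.data`, guarded by no_grad.
with torch.no_grad():
    module.weight.normal_(mean=0.0, std=initializer_range)
    if module.bias is not None:
        module.bias.zero_()

print(module.weight.std().item())  # roughly 0.02
```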

@@ -591,11 +625,7 @@ def forward(
             "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
         }
 
-        hidden_states = inputs_embeds
-
-        # Apply muP input scaling if enabled
-        if self.config.mup_enabled:
-            hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states = inputs_embeds * (self.config.hidden_size**0.5)
 
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
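With the muP flag gone, the sqrt(hidden_size) scaling of the input embeddings is unconditional. A hedged sketch of what the model entry point now computes, with toy sizes:

```python
import torch
from torch import nn

hidden_size, vocab_size = 64, 100
embed_tokens = nn.Embedding(vocab_size, hidden_size)

input_ids = torch.randint(0, vocab_size, (1, 8))
inputs_embeds = embed_tokens(input_ids)

# Unconditional input scaling, as in the diff (previously gated on config.mup_enabled).
hidden_states = inputs_embeds * (hidden_size**0.5)
print(hidden_states.shape)  # torch.Size([1, 8, 64])
```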

@@ -633,7 +663,7 @@ class AfmoeForCausalLM(AfmoePreTrainedModel, GenerationMixin):
     [`~PreTrainedModel.from_pretrained`] method to load the model weights.
     """
 
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
 
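The tied-weights mapping now names both ends explicitly. As a hedged, framework-free illustration of what tying `lm_head.weight` to `model.embed_tokens.weight` means at the tensor level:

```python
import torch
from torch import nn

hidden_size, vocab_size = 64, 100
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tie: the output projection reuses the embedding matrix (same storage, one parameter).
lm_head.weight = embed_tokens.weight
print(lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr())  # True
```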