
Commit 30c3a20

Address PR review feedback for AFMoE model
1 parent 46ca8d5 commit 30c3a20

File tree

5 files changed (+34, -1006 lines)


src/transformers/models/afmoe/configuration_afmoe.py

Lines changed: 5 additions & 8 deletions
@@ -17,7 +17,6 @@
 from typing import Optional

 from ...configuration_utils import PreTrainedConfig, layer_type_validation
-from ...modeling_rope_utils import rope_config_validation, standardize_rope_params
 from ...utils import logging


@@ -103,6 +102,9 @@ class AfmoeConfig(PreTrainedConfig):
             `global_attn_every_n_layers`.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        mup_enabled (`bool`, *optional*, defaults to `False`):
+            Whether to enable muP (Maximal Update Parametrization) input scaling. When enabled, input embeddings
+            are scaled by `sqrt(hidden_size)`.

     Example:
     ```python
@@ -157,6 +159,7 @@ def __init__(
         sliding_window: Optional[int] = 1024,
         layer_types: Optional[list] = None,
         attention_dropout: Optional[float] = 0.0,
+        mup_enabled: Optional[bool] = False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -187,6 +190,7 @@ def __init__(
         self.attention_dropout = attention_dropout
         self.global_attn_every_n_layers = global_attn_every_n_layers
         self.sliding_window = sliding_window
+        self.mup_enabled = mup_enabled
         self.layer_types = layer_types
         if self.layer_types is None:
             self.layer_types = [
@@ -200,13 +204,6 @@ def __init__(

         self.num_key_value_heads = num_key_value_heads

-        # Setup and validate rope configs
-        self.rope_parameters = rope_scaling
-        standardize_rope_params(self, rope_theta=rope_theta)
-        if self.rope_scaling is not None and "type" in self.rope_scaling:
-            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        rope_config_validation(self)
-
         super().__init__(
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
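
For reference, a minimal sketch of how the new `mup_enabled` flag is meant to be used. This is illustrative only: it assumes `AfmoeConfig` is exported from `transformers` once the AFMoE model is merged, and the scaling shown just restates the docstring above.

```python
# Illustrative sketch, not part of this diff: assumes AfmoeConfig is importable
# from transformers after the AFMoE model lands.
from transformers import AfmoeConfig

config = AfmoeConfig(mup_enabled=True)  # defaults to False

# When enabled, the model multiplies the input embeddings by sqrt(hidden_size)
# before the first decoder layer (see the modeling change below).
embedding_scale = config.hidden_size ** 0.5
```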

src/transformers/models/afmoe/modeling_afmoe.py

Lines changed: 15 additions & 14 deletions
@@ -175,18 +175,16 @@ def __init__(self, config):
         self.route_scale = config.route_scale
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

-    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor | None = None):
+    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor):
         _, _, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

         scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))

-        if expert_bias is not None:
-            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
-            top_scores = scores.gather(dim=1, index=selected_experts)
-        else:
-            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
+        _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
+        top_scores = scores.gather(dim=1, index=selected_experts)

+        # Normalize routing weights (default: True for sigmoid scoring)
         if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
             top_scores = top_scores / denominator
@@ -202,8 +200,6 @@ class AfmoeExperts(nn.ModuleList):
     This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
     """

-    _checkpoint_conversion_mapping = {"experts": "experts"}
-
     def __init__(self, config: AfmoeConfig):
         super().__init__()
         self.top_k = config.num_experts_per_tok
@@ -376,6 +372,7 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
+        self.rotary_fn = apply_rotary_pos_emb
         # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
         # We only add AFMoE-specific attributes
         self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
@@ -542,15 +539,15 @@ class AfmoePreTrainedModel(PreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.zero_()
+                nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
-            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
-                module.weight[module.padding_idx].zero_()
+                nn.init.zeros_(module.weight[module.padding_idx])
         elif isinstance(module, AfmoeRMSNorm):
-            module.weight.fill_(1.0)
+            nn.init.ones_(module.weight)


 @auto_docstring
@@ -628,7 +625,11 @@ def forward(
                 "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
             }

-        hidden_states = inputs_embeds * (self.config.hidden_size**0.5)
+        hidden_states = inputs_embeds
+
+        # Apply muP input scaling if enabled
+        if self.config.mup_enabled:
+            hidden_states = hidden_states * (self.config.hidden_size**0.5)

         position_embeddings = self.rotary_emb(hidden_states, position_ids)

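
The router now always expects an `expert_bias` tensor rather than treating it as optional. A standalone sketch of the selection logic above, with illustrative shapes and names (this is not the actual AFMoE router module):

```python
# Standalone sketch of the bias-adjusted top-k routing; shapes and values are illustrative.
import torch

num_tokens, num_experts, top_k = 4, 8, 2
scores = torch.sigmoid(torch.randn(num_tokens, num_experts))  # sigmoid scoring in float32
expert_bias = torch.zeros(num_experts)                        # load-balancing bias, used only for selection

# Experts are selected on the biased scores, but the routing weights come from the raw scores.
_, selected_experts = torch.topk(scores + expert_bias, k=top_k, dim=1)
top_scores = scores.gather(dim=1, index=selected_experts)

# With route_norm enabled, the per-token routing weights are normalized to sum to 1.
top_scores = top_scores / (top_scores.sum(dim=-1, keepdim=True) + 1e-20)
```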

src/transformers/models/afmoe/modular_afmoe.py

Lines changed: 14 additions & 14 deletions
@@ -132,18 +132,16 @@ def __init__(self, config):
         self.route_scale = config.route_scale
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

-    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor | None = None):
+    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor):
         _, _, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

         scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))

-        if expert_bias is not None:
-            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
-            top_scores = scores.gather(dim=1, index=selected_experts)
-        else:
-            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
+        _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
+        top_scores = scores.gather(dim=1, index=selected_experts)

+        # Normalize routing weights (default: True for sigmoid scoring)
         if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
             top_scores = top_scores / denominator
@@ -159,8 +157,6 @@ class AfmoeExperts(nn.ModuleList):
     This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
     """

-    _checkpoint_conversion_mapping = {"experts": "experts"}
-
     def __init__(self, config: AfmoeConfig):
         super().__init__()
         self.top_k = config.num_experts_per_tok
@@ -421,15 +417,15 @@ class AfmoePreTrainedModel(LlamaPreTrainedModel):
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, nn.Linear):
-            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
-                module.bias.zero_()
+                nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
-            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             if module.padding_idx is not None:
-                module.weight[module.padding_idx].zero_()
+                nn.init.zeros_(module.weight[module.padding_idx])
         elif isinstance(module, AfmoeRMSNorm):
-            module.weight.fill_(1.0)
+            nn.init.ones_(module.weight)


 @auto_docstring
@@ -507,7 +503,11 @@ def forward(
                 "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
             }

-        hidden_states = inputs_embeds * (self.config.hidden_size**0.5)
+        hidden_states = inputs_embeds
+
+        # Apply muP input scaling if enabled
+        if self.config.mup_enabled:
+            hidden_states = hidden_states * (self.config.hidden_size**0.5)

         position_embeddings = self.rotary_emb(hidden_states, position_ids)

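
Both `modeling_afmoe.py` and `modular_afmoe.py` switch `_init_weights` from in-place tensor methods to the functional `torch.nn.init` API. A minimal illustration of the equivalent calls (the `initializer_range` value here is a placeholder, not the model's actual default):

```python
# Illustrative comparison of the old in-place calls and the torch.nn.init equivalents.
import torch
import torch.nn as nn

initializer_range = 0.02  # placeholder; the real value comes from the model config

linear = nn.Linear(16, 16)
nn.init.normal_(linear.weight, mean=0.0, std=initializer_range)  # was: linear.weight.normal_(...)
if linear.bias is not None:
    nn.init.zeros_(linear.bias)                                  # was: linear.bias.zero_()

norm_weight = nn.Parameter(torch.empty(16))
nn.init.ones_(norm_weight)                                       # was: norm_weight.fill_(1.0)
```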

tests/models/afmoe/__init__.py

Lines changed: 0 additions & 13 deletions
@@ -1,14 +1 @@
-# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.

