
Commit 8d78a29

Address PR review feedback for AFMoE model
1 parent 46ca8d5 commit 8d78a29

4 files changed: +13 −979 lines

src/transformers/models/afmoe/configuration_afmoe.py

Lines changed: 5 additions & 0 deletions
@@ -103,6 +103,9 @@ class AfmoeConfig(PreTrainedConfig):
             `global_attn_every_n_layers`.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        mup_enabled (`bool`, *optional*, defaults to `False`):
+            Whether to enable muP (Maximal Update Parametrization) input scaling. When enabled, input embeddings
+            are scaled by `sqrt(hidden_size)`.
 
     Example:
     ```python
@@ -157,6 +160,7 @@ def __init__(
         sliding_window: Optional[int] = 1024,
         layer_types: Optional[list] = None,
         attention_dropout: Optional[float] = 0.0,
+        mup_enabled: Optional[bool] = False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -187,6 +191,7 @@ def __init__(
         self.attention_dropout = attention_dropout
         self.global_attn_every_n_layers = global_attn_every_n_layers
         self.sliding_window = sliding_window
+        self.mup_enabled = mup_enabled
         self.layer_types = layer_types
         if self.layer_types is None:
             self.layer_types = [
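
For context, a minimal sketch of how the new flag surfaces on the config object. This is not taken from the PR; it assumes the afmoe module added by this PR is importable and that the remaining constructor arguments can be left at their defaults:

```python
# Hedged sketch: assumes the AFMoE code from this PR is on the path and that
# AfmoeConfig builds cleanly from defaults; only `mup_enabled` is shown here.
from transformers.models.afmoe.configuration_afmoe import AfmoeConfig

config = AfmoeConfig(mup_enabled=True)
print(config.mup_enabled)  # True; omitting the argument keeps the default, False
```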

src/transformers/models/afmoe/modular_afmoe.py

Lines changed: 8 additions & 9 deletions
@@ -132,17 +132,14 @@ def __init__(self, config):
         self.route_scale = config.route_scale
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
 
-    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor | None = None):
+    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor):
         _, _, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))
 
-        if expert_bias is not None:
-            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
-            top_scores = scores.gather(dim=1, index=selected_experts)
-        else:
-            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
+        _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
+        top_scores = scores.gather(dim=1, index=selected_experts)
 
         if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
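
This hunk makes `expert_bias` a required argument: selection always runs on the biased scores, while the routing weights are still gathered from the raw sigmoid scores. A standalone sketch of that selection step with made-up shapes and a toy bias, not the model's own tensors:

```python
import torch

top_k = 2
num_experts = 8

scores = torch.sigmoid(torch.randn(4, num_experts))  # (tokens, num_experts), as after the gate
expert_bias = torch.zeros(num_experts)
expert_bias[3] = 10.0  # toy bias that strongly favors expert 3 during selection

# Selection uses the biased scores...
_, selected_experts = torch.topk(scores + expert_bias, k=top_k, dim=1)
# ...but the weights applied to the experts come from the unbiased scores.
top_scores = scores.gather(dim=1, index=selected_experts)
```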
@@ -159,8 +156,6 @@ class AfmoeExperts(nn.ModuleList):
     This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
     """
 
-    _checkpoint_conversion_mapping = {"experts": "experts"}
-
     def __init__(self, config: AfmoeConfig):
         super().__init__()
         self.top_k = config.num_experts_per_tok
@@ -507,7 +502,11 @@ def forward(
             "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
         }
 
-        hidden_states = inputs_embeds * (self.config.hidden_size**0.5)
+        hidden_states = inputs_embeds
+
+        # Apply muP input scaling if enabled
+        if self.config.mup_enabled:
+            hidden_states = hidden_states * (self.config.hidden_size**0.5)
 
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
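
The forward change gates the `sqrt(hidden_size)` embedding scaling behind `config.mup_enabled` instead of applying it unconditionally. A toy illustration of the same arithmetic outside the model, with a plain bool standing in for the real config:

```python
import torch

hidden_size = 64
mup_enabled = True  # stand-in for config.mup_enabled

inputs_embeds = torch.randn(1, 5, hidden_size)
hidden_states = inputs_embeds
if mup_enabled:
    # muP input scaling: multiply embeddings by sqrt(hidden_size)
    hidden_states = hidden_states * (hidden_size ** 0.5)
```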

tests/models/afmoe/__init__.py

Lines changed: 0 additions & 13 deletions
@@ -1,14 +1 @@
-# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 
