
Commit 02640f4

Address PR review feedback for AFMoE model
1 parent 46ca8d5 · commit 02640f4

5 files changed: +22 -996 lines changed

src/transformers/models/afmoe/configuration_afmoe.py

Lines changed: 5 additions & 8 deletions

@@ -17,7 +17,6 @@
 from typing import Optional
 
 from ...configuration_utils import PreTrainedConfig, layer_type_validation
-from ...modeling_rope_utils import rope_config_validation, standardize_rope_params
 from ...utils import logging
 
 
@@ -103,6 +102,9 @@ class AfmoeConfig(PreTrainedConfig):
             `global_attn_every_n_layers`.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        mup_enabled (`bool`, *optional*, defaults to `False`):
+            Whether to enable muP (Maximal Update Parametrization) input scaling. When enabled, input embeddings
+            are scaled by `sqrt(hidden_size)`.
 
     Example:
     ```python
@@ -157,6 +159,7 @@ def __init__(
         sliding_window: Optional[int] = 1024,
         layer_types: Optional[list] = None,
         attention_dropout: Optional[float] = 0.0,
+        mup_enabled: Optional[bool] = False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -187,6 +190,7 @@ def __init__(
         self.attention_dropout = attention_dropout
         self.global_attn_every_n_layers = global_attn_every_n_layers
         self.sliding_window = sliding_window
+        self.mup_enabled = mup_enabled
         self.layer_types = layer_types
         if self.layer_types is None:
             self.layer_types = [
@@ -200,13 +204,6 @@ def __init__(
 
         self.num_key_value_heads = num_key_value_heads
 
-        # Setup and validate rope configs
-        self.rope_parameters = rope_scaling
-        standardize_rope_params(self, rope_theta=rope_theta)
-        if self.rope_scaling is not None and "type" in self.rope_scaling:
-            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        rope_config_validation(self)
-
         super().__init__(
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
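
For reference, a minimal sketch of the new flag as documented above. The import path assumes the in-tree layout added by this model, the config is built with its defaults, and only `mup_enabled` is exercised; this is illustrative, not part of the change itself.

```python
from transformers.models.afmoe.configuration_afmoe import AfmoeConfig

# mup_enabled defaults to False, so existing configs keep the old behaviour.
config = AfmoeConfig()
assert config.mup_enabled is False

# Opting in only stores the flag; the model checks config.mup_enabled at forward
# time and scales input embeddings by sqrt(hidden_size) (see the modeling diff below).
mup_config = AfmoeConfig(mup_enabled=True)
assert mup_config.mup_enabled is True
```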

src/transformers/models/afmoe/modeling_afmoe.py

Lines changed: 9 additions & 9 deletions

@@ -175,17 +175,14 @@ def __init__(self, config):
         self.route_scale = config.route_scale
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
 
-    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor | None = None):
+    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor):
         _, _, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))
 
-        if expert_bias is not None:
-            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
-            top_scores = scores.gather(dim=1, index=selected_experts)
-        else:
-            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
+        _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
+        top_scores = scores.gather(dim=1, index=selected_experts)
 
         if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
@@ -202,8 +199,6 @@ class AfmoeExperts(nn.ModuleList):
     This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
     """
 
-    _checkpoint_conversion_mapping = {"experts": "experts"}
-
     def __init__(self, config: AfmoeConfig):
         super().__init__()
         self.top_k = config.num_experts_per_tok
@@ -376,6 +371,7 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
+        self.rotary_fn = apply_rotary_pos_emb
         # Parent LlamaAttention already sets: layer_idx, num_heads, num_key_value_heads, num_key_value_groups, head_dim
         # We only add AFMoE-specific attributes
         self.is_local_attention = config.layer_types[layer_idx] == "sliding_attention"
@@ -628,7 +624,11 @@ def forward(
             "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
         }
 
-        hidden_states = inputs_embeds * (self.config.hidden_size**0.5)
+        hidden_states = inputs_embeds
+
+        # Apply muP input scaling if enabled
+        if self.config.mup_enabled:
+            hidden_states = hidden_states * (self.config.hidden_size**0.5)
 
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
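
The router change above makes `expert_bias` a required argument: the bias is added to the sigmoid scores only to pick the top-k experts, while the routing weights are gathered from the unbiased scores. A minimal sketch of that selection logic on plain tensors (hypothetical shapes and bias values, not the actual gate module):

```python
import torch

torch.manual_seed(0)

num_tokens, num_experts, top_k = 4, 8, 2
scores = torch.sigmoid(torch.randn(num_tokens, num_experts))  # per-token expert probabilities
expert_bias = torch.zeros(num_experts)
expert_bias[3] = 10.0  # load-balancing bias strongly favouring expert 3

# The bias influences which experts get selected...
_, selected_experts = torch.topk(scores + expert_bias, k=top_k, dim=1)
# ...but the routing weights come from the unbiased scores.
top_scores = scores.gather(dim=1, index=selected_experts)

assert (selected_experts == 3).any(dim=1).all()  # the biased expert is always chosen
assert top_scores.max() <= 1.0                   # weights remain plain sigmoid scores
```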

src/transformers/models/afmoe/modular_afmoe.py

Lines changed: 8 additions & 9 deletions

@@ -132,17 +132,14 @@ def __init__(self, config):
         self.route_scale = config.route_scale
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
 
-    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor | None = None):
+    def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor):
         _, _, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))
 
-        if expert_bias is not None:
-            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
-            top_scores = scores.gather(dim=1, index=selected_experts)
-        else:
-            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
+        _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
+        top_scores = scores.gather(dim=1, index=selected_experts)
 
         if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
@@ -159,8 +156,6 @@ class AfmoeExperts(nn.ModuleList):
     This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
     """
 
-    _checkpoint_conversion_mapping = {"experts": "experts"}
-
    def __init__(self, config: AfmoeConfig):
         super().__init__()
         self.top_k = config.num_experts_per_tok
@@ -507,7 +502,11 @@ def forward(
             "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
         }
 
-        hidden_states = inputs_embeds * (self.config.hidden_size**0.5)
+        hidden_states = inputs_embeds
+
+        # Apply muP input scaling if enabled
+        if self.config.mup_enabled:
+            hidden_states = hidden_states * (self.config.hidden_size**0.5)
 
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
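
As in `modeling_afmoe.py`, the `sqrt(hidden_size)` embedding scaling that used to be unconditional is now gated behind `config.mup_enabled`. A small numeric sketch of the behavioural difference, using hypothetical tensors in place of the real forward pass:

```python
import torch

torch.manual_seed(0)

hidden_size = 16
inputs_embeds = torch.randn(2, 5, hidden_size)

# Old behaviour: the sqrt(hidden_size) scaling was applied unconditionally.
always_scaled = inputs_embeds * (hidden_size**0.5)

# New behaviour: scaling happens only when config.mup_enabled is True.
mup_enabled = False  # stands in for config.mup_enabled
hidden_states = inputs_embeds * (hidden_size**0.5) if mup_enabled else inputs_embeds

assert torch.equal(hidden_states, inputs_embeds)      # default path: embeddings pass through unchanged
assert not torch.equal(always_scaled, inputs_embeds)  # the previous code always rescaled them
```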

tests/models/afmoe/__init__.py

Lines changed: 0 additions & 13 deletions

@@ -1,14 +1 @@
-# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 
