
Commit ecb7438

Fix expert_bias routing in AFMoE
1 parent 826cb12

File tree

2 files changed (+12, -6):
    src/transformers/models/afmoe/modeling_afmoe.py
    src/transformers/models/afmoe/modular_afmoe.py

src/transformers/models/afmoe/modeling_afmoe.py

Lines changed: 6 additions & 3 deletions
@@ -176,13 +176,16 @@ def __init__(self, config):
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
 
     def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor | None = None):
-        # Keep expert_bias argument for checkpoint/backwards compatibility (it is always zero in released models).
-        del expert_bias
         _, _, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))
-        top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
+
+        if expert_bias is not None:
+            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
+            top_scores = scores.gather(dim=1, index=selected_experts)
+        else:
+            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
 
         if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
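For reference, a minimal standalone sketch of the routing behaviour this diff introduces (not the actual AFMoE router module; the function name, signature, and shapes below are illustrative assumptions): the bias is added to the sigmoid scores only to pick the top-k experts, while the routing weights are gathered from the unbiased scores.

    # Minimal sketch of the biased top-k routing shown in the diff above.
    # `route_tokens` and its signature are hypothetical, not the transformers API.
    import torch
    import torch.nn as nn

    def route_tokens(
        hidden_states: torch.Tensor,               # (batch, seq_len, hidden_size)
        gate: nn.Linear,                           # Linear(hidden_size, num_experts, bias=False)
        top_k: int,
        expert_bias: torch.Tensor | None = None,   # (num_experts,) or None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        _, _, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Router scores: sigmoid of the gate logits, computed in float32.
        scores = torch.sigmoid(gate(hidden_states).to(torch.float32))

        if expert_bias is not None:
            # The bias only affects which experts are selected ...
            _, selected_experts = torch.topk(scores + expert_bias, k=top_k, dim=1)
            # ... the gating weights still come from the unbiased scores.
            top_scores = scores.gather(dim=1, index=selected_experts)
        else:
            top_scores, selected_experts = torch.topk(scores, k=top_k, dim=1)

        return top_scores, selected_experts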

src/transformers/models/afmoe/modular_afmoe.py

Lines changed: 6 additions & 3 deletions
@@ -133,13 +133,16 @@ def __init__(self, config):
         self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
 
     def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor | None = None):
-        # Keep expert_bias argument for checkpoint/backwards compatibility (it is always zero in released models).
-        del expert_bias
         _, _, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
         scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))
-        top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
+
+        if expert_bias is not None:
+            _, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
+            top_scores = scores.gather(dim=1, index=selected_experts)
+        else:
+            top_scores, selected_experts = torch.topk(scores, k=self.top_k, dim=1)
 
         if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
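Building on the route_tokens sketch above, a quick hypothetical check of what the fix enables: a large bias on one expert forces it into every token's selection, but the returned weights remain the unbiased sigmoid scores.

    # Hypothetical usage of the route_tokens sketch above (values chosen for illustration).
    gate = nn.Linear(16, 8, bias=False)
    x = torch.randn(2, 4, 16)

    bias = torch.zeros(8)
    bias[3] = 10.0  # strongly favour expert 3 during selection only

    weights, experts = route_tokens(x, gate, top_k=2, expert_bias=bias)
    print(experts)  # expert 3 now appears in every token's top-2
    print(weights)  # weights are still the unbiased sigmoid scores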
