Conversation

@alyosha-swamy
Contributor

Summary

This PR adds support for the AFMoE (Arcee Foundational Mixture of Experts) model architecture for the upcoming Trinity-Mini and Trinity-Nano releases. AFMoE is a decoder-only transformer with a sparse Mixture of Experts (MoE) design that combines token-choice routing, shared experts, and several architectural refinements for efficient inference and improved performance.

Model Description

AFMoE features the following key architectural components:

  • Mixture of Experts with Shared Experts: Combines routed experts (activated per-token via learned routing) with always-active shared experts for stable base computation

  • Token-Choice Routing: Uses sigmoid- or softmax-based routing with normalization and scaling for expert selection (see the sketch after this list)

  • Q/K Normalization and Gating: Applies RMSNorm to query and key projections and uses sigmoid gating on attention outputs for improved training stability

  • Hybrid Attention Patterns: Alternates between sliding window attention and full attention across layers for efficiency with long contexts

  • Dual Normalization: Uses pre- and post-normalization around both attention and MLP blocks for training stability

  • Configurable Dense Layers: Allows initial layers to use dense MLPs before transitioning to sparse MoE layers (num_dense_layers)
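To make the routing description concrete, here is a minimal, self-contained sketch of token-choice routing combined with an always-active shared expert, as referenced in the list above. It is illustrative only: the class name, layer sizes, SiLU activation, epsilon, and the sigmoid/normalization choices are assumptions for the sketch, not the PR's actual AfmoeMoE or AfmoeTokenChoiceRouter code.

```python
import torch
import torch.nn as nn


class TokenChoiceMoESketch(nn.Module):
    """Illustrative only: per-token routed experts plus an always-active shared expert."""

    def __init__(self, hidden_size=64, intermediate_size=128, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)  # router logits
        make_ffn = lambda: nn.Sequential(
            nn.Linear(hidden_size, intermediate_size), nn.SiLU(), nn.Linear(intermediate_size, hidden_size)
        )
        self.experts = nn.ModuleList(make_ffn() for _ in range(num_experts))
        self.shared_expert = make_ffn()  # runs on every token: a stable dense path next to the sparse one

    def forward(self, hidden_states):
        batch, seq_len, hidden = hidden_states.shape
        flat = hidden_states.reshape(-1, hidden)
        scores = torch.sigmoid(self.gate(flat))                 # sigmoid routing (softmax is the other option)
        top_scores, selected = torch.topk(scores, k=self.top_k, dim=-1)
        top_scores = top_scores / (top_scores.sum(dim=-1, keepdim=True) + 1e-9)  # normalize the selected scores
        routed = torch.zeros_like(flat)
        for expert_idx, expert in enumerate(self.experts):
            token_idx, slot = (selected == expert_idx).nonzero(as_tuple=True)
            if token_idx.numel():
                routed[token_idx] += top_scores[token_idx, slot, None] * expert(flat[token_idx])
        return (routed + self.shared_expert(flat)).reshape(batch, seq_len, hidden)


# quick smoke test
out = TokenChoiceMoESketch()(torch.randn(2, 5, 64))
print(out.shape)  # torch.Size([2, 5, 64])
```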

Implementation Details

  • Modular implementation leveraging transformers' modular architecture:

    • Efficient AfmoeRMSNorm for layer normalization

    • AfmoeRotaryEmbedding for positional encoding

    • AfmoeAttention class implementing Q/K normalization and output gating

    • AfmoeTokenChoiceRouter for expert selection

    • AfmoeMoE class implementing shared + routed experts architecture

    • AfmoeDecoderLayer integrating attention and MoE blocks with dual normalization (sketched below)
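The dual-normalization layout of the decoder layer, sketched below as referenced above, is the "normalize in, normalize out, then residual add" pattern around both sub-blocks. This is not the PR's AfmoeDecoderLayer: it assumes torch.nn.RMSNorm (PyTorch ≥ 2.4) as a stand-in for AfmoeRMSNorm, and the attribute names and plug-in attention/mlp modules are invented for illustration.

```python
import torch
import torch.nn as nn


class DualNormDecoderLayerSketch(nn.Module):
    """Illustrative only: pre- and post-normalization around both the attention and the MLP/MoE block."""

    def __init__(self, hidden_size, attention, mlp):
        super().__init__()
        self.pre_attn_norm = nn.RMSNorm(hidden_size)
        self.post_attn_norm = nn.RMSNorm(hidden_size)
        self.pre_mlp_norm = nn.RMSNorm(hidden_size)
        self.post_mlp_norm = nn.RMSNorm(hidden_size)
        self.attention = attention  # e.g. an attention module with Q/K norm and output gating
        self.mlp = mlp              # dense MLP in the first num_dense_layers layers, MoE block afterwards

    def forward(self, hidden_states, **attn_kwargs):
        # Attention block: normalize in, attend, normalize out, then add the residual.
        residual = hidden_states
        hidden_states = self.attention(self.pre_attn_norm(hidden_states), **attn_kwargs)
        hidden_states = residual + self.post_attn_norm(hidden_states)
        # MLP / MoE block with the same dual-norm pattern.
        residual = hidden_states
        hidden_states = self.mlp(self.pre_mlp_norm(hidden_states))
        return residual + self.post_mlp_norm(hidden_states)


# quick smoke test with identity stand-ins for attention and mlp
layer = DualNormDecoderLayerSketch(64, attention=nn.Identity(), mlp=nn.Identity())
print(layer(torch.randn(2, 5, 64)).shape)  # torch.Size([2, 5, 64])
```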

Testing

  • Added comprehensive test suite following standard transformers test patterns
  • Tests for core functionality:
    • Model initialization and weight loading
    • Forward and backward passes
    • Attention mechanism (sliding window + full attention patterns)
    • MoE routing and expert selection
    • RoPE embeddings
    • KV cache compatibility
  • Integration tests with example checkpoints
  • Verified compatibility with existing transformer infrastructure
  • Model loading and inference verified with arcee-ai/Trinity-Mini

Documentation

  • Comprehensive model documentation in docs/source/en/model_doc/afmoe.md
  • Detailed architecture descriptions and usage examples
  • All configuration parameters documented with clear descriptions
  • Example code for both Pipeline and AutoModel usage patterns (a brief sketch follows below)
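The authoritative snippets live in docs/source/en/model_doc/afmoe.md; the sketch below only indicates the general shape of the two usage patterns, using the arcee-ai/Trinity-Mini checkpoint mentioned in the Testing section (prompt text and generation arguments are placeholders).

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Pipeline usage
generator = pipeline("text-generation", model="arcee-ai/Trinity-Mini")
print(generator("Mixture-of-experts models work by", max_new_tokens=64)[0]["generated_text"])

# AutoModel usage
tokenizer = AutoTokenizer.from_pretrained("arcee-ai/Trinity-Mini")
model = AutoModelForCausalLM.from_pretrained("arcee-ai/Trinity-Mini")
inputs = tokenizer("Mixture-of-experts models work by", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```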

@alyosha-swamy alyosha-swamy force-pushed the add_afmoe_model branch 4 times, most recently from 6b08d17 to e3ad5e9 Compare November 12, 2025 19:23
Collaborator

@ArthurZucker ArthurZucker left a comment

nice work!

Comment on lines 47 to 92
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids: Optional[torch.Tensor] = None, unsqueeze_dim: int = 1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
Collaborator

these can also be imported from Llama! 😉
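For reference, one way to follow this suggestion would be to import the helpers in modular_afmoe.py rather than redefine them; these names exist in transformers' Llama implementation, though the exact import the PR settled on is not shown here.

```python
# Hedged sketch: reuse the Llama helpers instead of redefining them in modular_afmoe.py.
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    eager_attention_forward,
    repeat_kv,
    rotate_half,
)
```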

top_scores, selected_experts = self.router(hidden_states, self.expert_bias)

# Process through shared experts
if self.shared_experts is not None:
Collaborator

same comment, is this used by the released model or not?

Collaborator

not addressed

Contributor Author

the first layer is a standard dense FFN, and all subsequent layers use the MoE block

Collaborator

In that case the arch should be different! Use a normal MLP for mlp and experts for the experts! 🤗
You can set the first layer, then just do +=; we want to avoid code paths as much as possible.

Comment on lines +339 to +344
# MoE or dense FFN
self.moe_enabled = layer_idx >= config.num_dense_layers
if self.moe_enabled:
    self.mlp = AfmoeMoE(config)
else:
    self.mlp = AfmoeMLP(config)
Collaborator

is moe disabled on any of the released ckpts? 🤗

Collaborator

again here

@alyosha-swamy alyosha-swamy force-pushed the add_afmoe_model branch 2 times, most recently from 8c6bdb4 to 045776d Compare November 21, 2025 16:42
This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
"""

_checkpoint_conversion_mapping = {"experts": "experts"}
Collaborator

Suggested change
_checkpoint_conversion_mapping = {"experts": "experts"}

top_scores, selected_experts = self.router(hidden_states, self.expert_bias)

# Process through shared experts
if self.shared_experts is not None:
Collaborator

not addressed

key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

if self.is_local_attention:
Collaborator

I did not get an answer

Comment on lines +339 to +344
# MoE or dense FFN
self.moe_enabled = layer_idx >= config.num_dense_layers
if self.moe_enabled:
    self.mlp = AfmoeMoE(config)
else:
    self.mlp = AfmoeMLP(config)
Collaborator

again here

tests_output.txt Outdated
Collaborator

to remove

Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks, we tend to try and remove code paths as much as possible; if not done here we'll do it post-release!

_, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
top_scores = scores.gather(dim=1, index=selected_experts)

if self.route_norm:
Collaborator

is this always True or False? (cf removing code path :)
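For context, route_norm presumably gates a renormalization of the gathered top-k scores; a tiny illustration of that kind of step follows (the PR's exact formula and epsilon are not reproduced here).

```python
import torch

# Hypothetical route_norm-style step: rescale each token's top-k routing scores so they sum to 1.
top_scores = torch.tensor([[0.6, 0.3], [0.2, 0.2]])  # [num_tokens, top_k] gathered scores
top_scores = top_scores / (top_scores.sum(dim=-1, keepdim=True) + 1e-20)
print(top_scores)  # each row now sums to 1
```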

        return top_scores, selected_experts


class AfmoeExperts(nn.ModuleList):
Collaborator

you could just inherit from Mixtral or Qwen2Moe, it should be the same, no?

Contributor Author

The checkpoint weight structure is different in AFMoE

Collaborator

We have an online weight converter, but no worries :)

Comment on lines 418 to 427
if isinstance(module, nn.Linear):
    module.weight.normal_(mean=0.0, std=self.config.initializer_range)
    if module.bias is not None:
        module.bias.zero_()
elif isinstance(module, nn.Embedding):
    module.weight.normal_(mean=0.0, std=self.config.initializer_range)
    if module.padding_idx is not None:
        module.weight[module.padding_idx].zero_()
elif isinstance(module, AfmoeRMSNorm):
    module.weight.fill_(1.0)
Collaborator

these should not be used, can you use nn.init instead please! One of the CI checks will fail, as we require this for inits!
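A hedged sketch of an nn.init-based version of the block above (AfmoeRMSNorm and self.config come from the PR's modeling file; the final implementation may differ):

```python
from torch import nn
from torch.nn import init


def _init_weights(self, module):
    # Same scheme as the snippet above, expressed through torch.nn.init helpers.
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
        init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        init.normal_(module.weight, mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, AfmoeRMSNorm):  # RMSNorm class defined in the PR
        init.ones_(module.weight)
```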

@ArthurZucker
Collaborator

  • FAILED tests/models/afmoe/test_modeling_afmoe.py::AfmoeModelTest::test_attention_outputs - TypeError: object of type 'NoneType' has no len()
  • FAILED tests/models/afmoe/test_modeling_afmoe.py::AfmoeModelTest::test_prompt_lookup_decoding_matches_greedy_search - TypeError: 'NoneType' object is not subscriptable
  • FAILED tests/models/afmoe/test_modeling_afmoe.py::AfmoeModelTest::test_sample_generate_dict_output - AssertionError: Lists differ: [False, False, False] != [True, True, True]

I think the output attention recorder is wrong. Once fixed we can merge.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Remove shared-expert if/else as it defaults to 2
Remove `route_norm` as it defaults to `True`.

Make tests smaller and faster
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: afmoe, auto

@ArthurZucker ArthurZucker merged commit cac0a28 into huggingface:main Nov 29, 2025
16 of 21 checks passed
@LysandreJik
Member

It seems like this model wasn't added to src/transformers/models/__init__.py?

@Rocketknight1
Member

Seems like it - how did the CI pass?
