vllm/model_executor/models/glm4_moe.py (3 changes: 2 additions & 1 deletion)
@@ -123,6 +123,7 @@ def __init__(
             config.n_routed_experts,
             bias=False,
             quant_config=None,
+            params_dtype=torch.float32,
             prefix=f"{prefix}.gate")

         self.gate.e_score_correction_bias = nn.Parameter(
@@ -180,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
Contributor review comment (severity: high):

The explicit cast hidden_states.to(dtype=torch.float32) is redundant and introduces an unnecessary memory copy, which can negatively impact performance.

Since self.gate.weight is already of dtype=torch.float32 (due to the change in __init__), torch.nn.functional.linear (which is called internally by ColumnParallelLinear) will automatically perform the matrix multiplication in float32 by upcasting the hidden_states tensor. This implicit type promotion is more efficient than an explicit cast.

Removing the explicit cast will rely on this standard PyTorch behavior and avoid the overhead, while still achieving the goal of performing the gate computation in float32.

Suggested change:
-        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
+        router_logits, _ = self.gate(hidden_states)
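Whether torch.nn.functional.linear really promotes a bfloat16 activation against a float32 weight, rather than raising a dtype-mismatch error, is worth verifying empirically before dropping the cast. A minimal standalone probe, assuming only plain PyTorch and arbitrary shapes (not part of this PR), might look like:

# Probe: does F.linear accept a bfloat16 input against a float32 weight,
# or is the explicit upcast still required? Shapes and values are arbitrary.
import torch
import torch.nn.functional as F

x = torch.randn(4, 16, dtype=torch.bfloat16)  # stand-in for hidden_states
w = torch.randn(8, 16, dtype=torch.float32)   # stand-in for the fp32 gate weight

try:
    out = F.linear(x, w)
    print("mixed-dtype linear ok, output dtype:", out.dtype)
except RuntimeError as err:
    print("mixed-dtype linear rejected:", err)
    out = F.linear(x.to(torch.float32), w)    # explicit upcast path
    print("explicit upcast ok, output dtype:", out.dtype)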

         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits) * self.routed_scaling_factor
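For context outside the vLLM codebase, the pattern this diff implements can be sketched as a hypothetical standalone module (Fp32Gate below is an illustration, not vLLM's gate linear layer): the gate weight lives in float32 so the routing matmul and the resulting router logits are computed in full precision, while the experts still consume the original lower-precision hidden_states.

# Minimal sketch of the fp32-gate pattern (hypothetical module, not the
# vLLM implementation): gate weight and router logits in float32, expert
# inputs left in the model's compute dtype.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Fp32Gate(nn.Module):
    def __init__(self, hidden_size: int, n_routed_experts: int):
        super().__init__()
        # Router weight kept in float32 for numerically stable expert selection.
        self.weight = nn.Parameter(
            torch.empty(n_routed_experts, hidden_size, dtype=torch.float32))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Upcast activations so the matmul runs entirely in float32,
        # mirroring hidden_states.to(dtype=torch.float32) in the diff.
        return F.linear(hidden_states.to(torch.float32), self.weight)

gate = Fp32Gate(hidden_size=64, n_routed_experts=8)
x = torch.randn(2, 64, dtype=torch.bfloat16)   # model-dtype activations
router_logits = gate(x)                        # float32 logits for top-k routing
assert router_logits.dtype == torch.float32

The expert MLPs, as in the diff's self.experts(hidden_states=hidden_states, router_logits=router_logits) call, would still operate on the bfloat16 hidden_states; only expert selection runs in float32.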