Skip to content

Commit aa8f985

Browse files
authored
Merge pull request #925 from kvcache-ai/fix-gate-compile
fix-gate-compile
2 parents e788248 + 1149953 commit aa8f985

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

ktransformers/operators/gate.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def unload(self):
125125

126126
# adapted from https://github.com/vllm-project/vllm/blob/c77620d22d43daa7e0440e6267cbdd83f849ac64/vllm/model_executor/layers/fused_moe/fused_moe.py#L1071
127127
# This is used by the Deepseek-V2 and Deepseek-V3 model
128-
#@torch.compile(dynamic=True)
128+
@torch.compile(dynamic=True)
129129
def grouped_topk(hidden_states: torch.Tensor,
130130
gating_output: torch.Tensor,
131131
topk: int,
@@ -225,9 +225,8 @@ def forward(self, hidden_states) -> torch.Tensor:
225225
hidden_states.type(torch.float32), self.weight.type(torch.float32), None
226226
)
227227

228-
return grouped_topk(hidden_states, logits,
229-
self.top_k, self.norm_topk_prob,
230-
self.n_group, self.topk_group)
228+
return grouped_topk(hidden_states, logits, self.top_k, self.norm_topk_prob,
229+
self.n_group, self.topk_group, "sigmoid", self.e_score_correction_bias)
231230

232231
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
233232
if device is None: device = self.device

0 commit comments

Comments (0)