Skip to content

Commit 4c57bde

Browse files
zRzRzRzRzRzRzRpaulpak58
authored andcommitted
self.gate dtype update for GLM-4.5 (vllm-project#22203)
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> Signed-off-by: Paul Pak <paulpak58@gmail.com>
1 parent 4c8daae commit 4c57bde

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
607607
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
608608
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
609609
| `Glm4MoeForCausalLM` | GLM-4.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
610-
| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V-Air`, etc. | ✅︎ | ✅︎ | ✅︎ |
610+
| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
611611
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
612612
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
613613
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |

tests/models/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def check_available_online(
385385
trust_remote_code=True,
386386
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
387387
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501
388-
"Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V-Air",
388+
"Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V",
389389
is_available_online=False), # noqa: E501
390390
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
391391
trust_remote_code=True,

vllm/model_executor/models/glm4_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __init__(
123123
config.n_routed_experts,
124124
bias=False,
125125
quant_config=None,
126+
params_dtype=torch.float32,
126127
prefix=f"{prefix}.gate")
127128

128129
self.gate.e_score_correction_bias = nn.Parameter(
@@ -180,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
180181

181182
if self.n_shared_experts is not None:
182183
shared_output = self.shared_experts(hidden_states)
183-
router_logits, _ = self.gate(hidden_states)
184+
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
184185
final_hidden_states = self.experts(
185186
hidden_states=hidden_states,
186187
router_logits=router_logits) * self.routed_scaling_factor

0 commit comments

Comments
 (0)