diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index be3d51a025ed..017a339ffca0 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -606,7 +606,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Glm4MoeForCausalLM` | GLM-4.5 | T + IE+ + VE+ | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V-Air`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
| `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
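The documentation hunk above points the `Glm4v_moeForConditionalGeneration` row at the `zai-org/GLM-4.5V` checkpoint instead of the `-Air` variant name. As a quick smoke test of the renamed entry, a minimal sketch using vLLM's offline `LLM` API — the prompt and sampling values are illustrative assumptions, not part of this change:

```python
# Minimal sketch: load the renamed checkpoint through vLLM's offline API.
# Assumes the zai-org/GLM-4.5V weights are reachable (Hub or local cache);
# prompt and sampling values are illustrative, not taken from this PR.
from vllm import LLM, SamplingParams

llm = LLM(model="zai-org/GLM-4.5V", trust_remote_code=True)
outputs = llm.generate(
    ["Briefly explain what a mixture-of-experts router does."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```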
diff --git a/tests/models/registry.py b/tests/models/registry.py
index d86bd20fb0e3..47057d32e9cd 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -383,7 +383,7 @@ def check_available_online(
trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501
- "Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V-Air",
+ "Glm4v_moeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V",
is_available_online=False), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
trust_remote_code=True,
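The test registry mirrors the rename. Note that the entry keeps `is_available_online=False`; in this registry that flag appears to mark checkpoints CI should not try to fetch from the Hub, so the architecture continues to be exercised without downloading the real weights until the flag is lifted.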
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index c702684c6caa..bd3e27662ee7 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -123,6 +123,7 @@ def __init__(
config.n_routed_experts,
bias=False,
quant_config=None,
+ params_dtype=torch.float32,
prefix=f"{prefix}.gate")
self.gate.e_score_correction_bias = nn.Parameter(
@@ -180,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
- router_logits, _ = self.gate(hidden_states)
+ router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
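The two `glm4_moe.py` hunks work as a pair: the router gate's weights are now allocated in float32 (`params_dtype=torch.float32`), and the activations are upcast to float32 before the gate matmul, so the routing logits are computed entirely in full precision. MoE routing takes a top-k over these logits, and near-tied logits can rank differently under bfloat16 rounding, silently changing which experts a token is dispatched to. A standalone sketch of the effect — the hidden size, expert count, and top-k below are illustrative, not GLM-4.5's actual config:

```python
import torch

torch.manual_seed(0)

# Illustrative sizes only; not GLM-4.5's actual hidden size, expert
# count, or top-k.
hidden = torch.randn(8, 4096, dtype=torch.bfloat16)  # bf16 activations
gate_weight = torch.randn(4096, 160)                 # fp32 router weights

# Routing entirely in bf16: weights rounded to bf16 before the matmul.
logits_bf16 = hidden @ gate_weight.to(torch.bfloat16)
# Routing in fp32, as in this diff: upcast activations, keep fp32 weights.
logits_fp32 = hidden.to(torch.float32) @ gate_weight

# Near-tied logits can rank differently under bf16, flipping the top-k
# expert assignment for some tokens.
topk_bf16 = logits_bf16.float().topk(8, dim=-1).indices
topk_fp32 = logits_fp32.topk(8, dim=-1).indices
changed = (topk_bf16.sort(-1).values != topk_fp32.sort(-1).values).any(-1)
print(f"tokens whose expert set changed under bf16: {int(changed.sum())}/8")
```

Keeping the gate unquantized (`quant_config=None`, already present above) fits the same goal: the router is a tiny fraction of total FLOPs, so running it in full precision costs little while stabilizing expert selection.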