
Commit 72b1c2a

[Bugfix] Use latency MOE backend as default for Flashinfer and other misc fixes (#27439)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
1 parent: e0919f3

File tree

7 files changed (+47, -12 lines changed)

csrc/quantization/fp4/nvfp4_quant_kernels.cu
tests/kernels/quantization/test_nvfp4_quant.py
vllm/_custom_ops.py
vllm/envs.py
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/modelopt.py
csrc/quantization/fp4/nvfp4_quant_kernels.cu

Lines changed: 20 additions & 2 deletions

@@ -31,6 +31,13 @@

 namespace vllm {

+template <typename Int>
+__host__ __device__ inline Int round_up(Int x, Int y) {
+  static_assert(std::is_integral_v<Int>,
+                "round_up argument must be integral type");
+  return (x + y - 1) / y * y;
+}
+
 // Use UE4M3 by default.
 template <class Type, bool UE8M0_SF = false>
 __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
@@ -42,10 +49,21 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                 "Vec size is not matched.");

+  int sf_m = round_up<int>(numRows, 128);
+  int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE;
+  int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4;
+  for (int row = numRows + blockIdx.x; row < sf_m; row += gridDim.x) {
+    // Each thread writes 4 uint32_t elements.
+    for (int col = sf_n_unpadded + threadIdx.x * 4; col < sf_n_int;
+         col += blockDim.x * 4) {
+      SFout[row * sf_n_int + col] = 0x00;
+    }
+  }
+
   // Get the global scaling factor, which will be applied to the SF.
   // Note SFScale is the same as next GEMM's alpha, which is
   // (448.f / (Alpha_A / 6.f)).
-  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
+  float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0];

   // Input tensor row/col loops.
   for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
@@ -64,7 +82,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
                                                     rowIdx, colIdx, numCols, SFout);

       out_pos =
-          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out);
     }
   }
 }
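Note: the new loop zero-fills the padded tail rows of SFout (rows are rounded up to a multiple of 128, and packed int32 columns to a multiple of 4 scale entries), which lines up with the switch from torch.zeros to torch.empty in vllm/_custom_ops.py below. A minimal Python sketch of the same padding arithmetic, assuming CVT_FP4_SF_VEC_SIZE is the NVFP4 scale block size of 16; the helper names here are illustrative, not from the commit:

def round_up(x: int, y: int) -> int:
    # Same rounding as the CUDA helper: smallest multiple of y that is >= x.
    return (x + y - 1) // y * y

def padded_scale_shape(num_rows: int, num_cols: int) -> tuple[int, int]:
    sf_m = round_up(num_rows, 128)               # rows padded to a multiple of 128
    sf_n_unpadded = num_cols // 16               # one FP8 scale per 16 elements (assumed block size)
    sf_n_int = round_up(sf_n_unpadded, 4) // 4   # four FP8 scales packed per int32
    return sf_m, sf_n_int                        # entries beyond the real data are zero-filled by the kernel

print(padded_scale_shape(120, 64))  # -> (128, 1)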

tests/kernels/quantization/test_nvfp4_quant.py

Lines changed: 0 additions & 2 deletions

@@ -168,9 +168,7 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
     out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)

     out, out_scale = ops.scaled_fp4_quant(x, global_scale)
-
     scale_ans = recover_swizzled_scales(out_scale, m, n)
     out_ans = cast_from_fp4(out, m, n)
-
     torch.testing.assert_close(out_ans, out_ref)
     torch.testing.assert_close(scale_ans, scale_ref)

vllm/_custom_ops.py

Lines changed: 1 addition & 1 deletion

@@ -1385,7 +1385,7 @@ def scaled_fp4_quant(
     rounded_m = round_up(m, 128)
     scale_n = n // block_size
     rounded_n = round_up(scale_n, 4)
-    output_scale = torch.zeros(
+    output_scale = torch.empty(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
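For context, a hedged usage sketch mirroring the call pattern in the padded-quantization test above; the input shape, dtype, import spelling, and global_scale recipe are illustrative assumptions, not taken from this commit:

import torch
from vllm import _custom_ops as ops

x = torch.randn(120, 64, dtype=torch.half, device="cuda")        # assumed shape/dtype
global_scale = (448.0 * 6.0) / x.abs().max().float().reshape(1)  # assumed NVFP4 global-scale recipe
out, out_scale = ops.scaled_fp4_quant(x, global_scale)
# out_scale is allocated as (round_up(m, 128), round_up(n // 16, 4) // 4) int32 values;
# after this commit the padded region is zero-filled by the CUDA kernel, so torch.empty is safe here.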

vllm/envs.py

Lines changed: 3 additions & 3 deletions

@@ -155,7 +155,7 @@
 VLLM_USE_FLASHINFER_MOE_FP16: bool = False
 VLLM_USE_FLASHINFER_MOE_FP8: bool = False
 VLLM_USE_FLASHINFER_MOE_FP4: bool = False
-VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput"
+VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "latency"
 VLLM_XGRAMMAR_CACHE_MB: int = 0
 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -1218,7 +1218,7 @@ def get_vllm_port() -> int | None:
     # - "latency":
     #   Uses TensorRT-LLM kernels optimized for low-latency inference.
     "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
-        "VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"]
+        "VLLM_FLASHINFER_MOE_BACKEND", "latency", ["throughput", "latency"]
     ),
     # Control the maximum number of tokens per expert supported by the
     # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
@@ -1325,7 +1325,7 @@ def get_vllm_port() -> int | None:
     "VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
         "VLLM_NVFP4_GEMM_BACKEND",
         None,
-        ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass"],
+        ["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "cutlass"],
     ),
     # Controls garbage collection during CUDA graph capture.
     # If set to 0 (default), enables GC freezing to speed up capture time.
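Note: with this change the FlashInfer MoE backend defaults to "latency" (TensorRT-LLM kernels). A hedged sketch of restoring the previous default; the variable has to be set before vLLM reads its environment configuration:

import os

# Restore the pre-change default ("throughput", i.e. the CUTLASS MoE path);
# set this before vLLM is imported/started so envs.py picks it up.
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "throughput"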

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py

Lines changed: 3 additions & 0 deletions

@@ -50,6 +50,9 @@ def __init__(self):
         elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
             self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
             assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
+            self.backend = "cutlass"
+            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

         if self.backend == "none":
             raise ValueError(
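Note: "cutlass" is now an accepted value for VLLM_NVFP4_GEMM_BACKEND (see vllm/envs.py above), and both this compressed-tensors scheme and the ModelOpt NVFP4 scheme below honor it, guarded by cutlass_fp4_supported(). A hedged selection sketch:

import os

# Request the CUTLASS NVFP4 GEMM path explicitly; vLLM still asserts
# cutlass_fp4_supported() when the quantization scheme is constructed.
os.environ["VLLM_NVFP4_GEMM_BACKEND"] = "cutlass"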

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 7 additions & 0 deletions

@@ -138,6 +138,13 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
             logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
             return Fp8MoeBackend.FLASHINFER_TRTLLM
         else:
+            if block_quant:
+                raise ValueError(
+                    "FlashInfer FP8 MoE throughput backend does not "
+                    "support block quantization. Please use "
+                    "VLLM_FLASHINFER_MOE_BACKEND=latency "
+                    "instead."
+                )
             logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM100")
             return Fp8MoeBackend.FLASHINFER_CUTLASS
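Note: the throughput (CUTLASS) FlashInfer FP8 MoE path does not support block-quantized weights, so that combination now fails fast instead of silently selecting an unsupported kernel. A hedged, self-contained sketch of the guard; the enum and function here are simplified stand-ins, not vLLM's actual API:

from enum import Enum

class Fp8MoeBackend(Enum):  # simplified stand-in for vLLM's enum
    FLASHINFER_TRTLLM = "trtllm"
    FLASHINFER_CUTLASS = "cutlass"

def pick_flashinfer_fp8_moe(backend: str, block_quant: bool) -> Fp8MoeBackend:
    # Simplified mirror of the new guard in get_fp8_moe_backend().
    if backend == "latency":
        return Fp8MoeBackend.FLASHINFER_TRTLLM
    if block_quant:
        raise ValueError(
            "FlashInfer FP8 MoE throughput backend does not support block "
            "quantization. Please use VLLM_FLASHINFER_MOE_BACKEND=latency instead."
        )
    return Fp8MoeBackend.FLASHINFER_CUTLASS

print(pick_flashinfer_fp8_moe("latency", block_quant=True))  # -> Fp8MoeBackend.FLASHINFER_TRTLLM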

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 13 additions & 4 deletions

@@ -221,7 +221,10 @@ def is_layer_excluded(self, prefix: str) -> bool:
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import (  # Avoid circular import
+            Attention,
+            MLAAttention,
+        )

         if isinstance(layer, LinearBase):
             if self.is_layer_excluded(prefix):
@@ -230,7 +233,7 @@ def get_quant_method(
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, (Attention, MLAAttention)):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptFp8MoEMethod(self, layer)
@@ -888,7 +891,10 @@ def is_layer_excluded(self, prefix: str) -> bool:
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import (  # Avoid circular import
+            Attention,
+            MLAAttention,
+        )

         skip_layer = self.is_layer_excluded(prefix)
         if isinstance(layer, LinearBase):
@@ -898,7 +904,7 @@ def get_quant_method(
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, (Attention, MLAAttention)):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             if skip_layer:
@@ -941,6 +947,9 @@ def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
         elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
             self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
             assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
+            self.backend = "cutlass"
+            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

         if self.backend == "none":
             raise ValueError(

0 commit comments
