@@ -221,7 +221,10 @@ def is_layer_excluded(self, prefix: str) -> bool:
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import (  # Avoid circular import
+            Attention,
+            MLAAttention,
+        )
 
         if isinstance(layer, LinearBase):
             if self.is_layer_excluded(prefix):
@@ -230,7 +233,7 @@ def get_quant_method(
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, (Attention, MLAAttention)):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             return ModelOptFp8MoEMethod(self, layer)
@@ -888,7 +891,10 @@ def is_layer_excluded(self, prefix: str) -> bool:
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
+        from vllm.attention.layer import (  # Avoid circular import
+            Attention,
+            MLAAttention,
+        )
 
         skip_layer = self.is_layer_excluded(prefix)
         if isinstance(layer, LinearBase):
@@ -898,7 +904,7 @@ def get_quant_method(
             if "vision_tower" in prefix or "vision_model" in prefix:
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
-        elif isinstance(layer, Attention):
+        elif isinstance(layer, (Attention, MLAAttention)):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
             if skip_layer:
@@ -941,6 +947,9 @@ def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
         elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
             self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
             assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
+        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
+            self.backend = "cutlass"
+            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
 
         if self.backend == "none":
             raise ValueError(
0 commit comments