diff --git a/pyproject.toml b/pyproject.toml index 7e6b11ce4..7df209686 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "truss" -version = "0.11.12" +version = "0.11.13.rc1" description = "A seamless bridge from model development to model delivery" authors = [ { name = "Pankaj Gupta", email = "no-reply@baseten.co" }, diff --git a/truss/base/trt_llm_config.py b/truss/base/trt_llm_config.py index 444f61310..0aa5aab6c 100644 --- a/truss/base/trt_llm_config.py +++ b/truss/base/trt_llm_config.py @@ -68,6 +68,7 @@ class TrussTRTLLMQuantizationType(str, Enum): FP8_KV = "fp8_kv" FP4 = "fp4" FP4_KV = "fp4_kv" + FP4_MLP_ONLY = "fp4_mlp_only" class TrussTRTLLMPluginConfiguration(PydanticTrTBaseModel): @@ -713,7 +714,9 @@ def trt_llm_common_validation(config: "TrussConfig"): "accelerators or newer (CUDA_COMPUTE>=89)" ) elif trt_llm_config.build.quantization_type in [ - TrussTRTLLMQuantizationType.FP4 + TrussTRTLLMQuantizationType.FP4, + TrussTRTLLMQuantizationType.FP4_KV, + TrussTRTLLMQuantizationType.FP4_MLP_ONLY, ] and config.resources.accelerator.accelerator in [ truss_config.Accelerator.H100, truss_config.Accelerator.L4,