From f6170b01e1fb936ec423f6aa0d0b15095648f867 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 6 Apr 2025 01:13:12 +0000 Subject: [PATCH 01/14] update transformers for llama 4 Signed-off-by: Qubitium --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f5e6e7d37..ef92b2c05 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ datasets>=3.2.0 numpy>=1.26.4 torch>=2.2.0 safetensors>=0.5.2 -transformers>=4.49.0 +transformers>=4.51.0 threadpoolctl>=3.6.0 packaging>=24.2 device-smi==0.4.1 From bb89c7607065fb908946ac54d26d29503d815166 Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sun, 6 Apr 2025 12:43:57 +0800 Subject: [PATCH 02/14] add Llama4GPTQ --- gptqmodel/models/auto.py | 2 ++ gptqmodel/models/definitions/__init__.py | 1 + gptqmodel/models/definitions/llama4.py | 37 ++++++++++++++++++++++++ tests/models/test_llama4.py | 0 4 files changed, 40 insertions(+) create mode 100644 gptqmodel/models/definitions/llama4.py create mode 100644 tests/models/test_llama4.py diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index c2454807c..2c76de52d 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -89,6 +89,7 @@ from .definitions.internlm import InternLMGPTQ # noqa: E402 from .definitions.internlm2 import InternLM2GPTQ # noqa: E402 from .definitions.llama import LlamaGPTQ # noqa: E402 +from .definitions.llama4 import Llama4GPTQ # noqa: E402 from .definitions.longllama import LongLlamaGPTQ # noqa: E402 from .definitions.minicpm import MiniCPMGPTQ # noqa: E402 from .definitions.minicpm3 import MiniCPM3GPTQ # noqa: E402 @@ -128,6 +129,7 @@ "gptj": GPTJGPTQ, "gpt2": GPT2GPTQ, "llama": LlamaGPTQ, + "llama4": Llama4GPTQ, "opt": OPTGPTQ, "moss": MOSSGPTQ, "chatglm": ChatGLM, diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py index c2dac053c..a23e87e93 100644 --- a/gptqmodel/models/definitions/__init__.py +++ b/gptqmodel/models/definitions/__init__.py @@ -41,6 +41,7 @@ from .internlm import InternLMGPTQ from .internlm2 import InternLM2GPTQ from .llama import LlamaGPTQ +from .llama4 import Llama4GPTQ from .longllama import LongLlamaGPTQ from .minicpm3 import MiniCPM3GPTQ from .mistral import MistralGPTQ diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py new file mode 100644 index 000000000..faa4d3d8a --- /dev/null +++ b/gptqmodel/models/definitions/llama4.py @@ -0,0 +1,37 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ..base import BaseGPTQModel + + +class Llama4GPTQ(BaseGPTQModel): + base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"] + pre_lm_head_norm_module = "language_model.model.norm" + + layers_node = "language_model.model.layers" + layer_type = "Llama4TextDecoderLayer" + + layer_modules = [ + ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], + ["self_attn.o_proj"], + + + ["feed_forward.experts.gate_up_proj"], + ["feed_forward.experts.down_proj"], + + ["feed_forward.shared_expert.down_proj"], + ["feed_forward.shared_expert.down_proj"], + ] diff --git a/tests/models/test_llama4.py b/tests/models/test_llama4.py new file mode 100644 index 000000000..e69de29bb From 0ebfb175f488125f871ef4f26157056c0c96a719 Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sun, 6 Apr 2025 14:45:59 +0800 Subject: [PATCH 03/14] use loader AutoModelForImageTextToText --- gptqmodel/models/definitions/llama4.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py index faa4d3d8a..07a596337 100644 --- a/gptqmodel/models/definitions/llama4.py +++ b/gptqmodel/models/definitions/llama4.py @@ -15,9 +15,11 @@ # limitations under the License. from ..base import BaseGPTQModel - +from transformers import AutoModelForImageTextToText class Llama4GPTQ(BaseGPTQModel): + loader = AutoModelForImageTextToText + base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"] pre_lm_head_norm_module = "language_model.model.norm" From 006d7a6172fed9c881edfdd37b1cea88aed662c9 Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sun, 6 Apr 2025 15:05:49 +0800 Subject: [PATCH 04/14] cleanup --- gptqmodel/models/definitions/llama4.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py index 07a596337..eb3f0dccf 100644 --- a/gptqmodel/models/definitions/llama4.py +++ b/gptqmodel/models/definitions/llama4.py @@ -30,10 +30,6 @@ class Llama4GPTQ(BaseGPTQModel): ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], ["self_attn.o_proj"], - - ["feed_forward.experts.gate_up_proj"], - ["feed_forward.experts.down_proj"], - - ["feed_forward.shared_expert.down_proj"], + ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj"], ["feed_forward.shared_expert.down_proj"], ] From 26f074aacd5b2aadfe33a68e9c4fd6117b936d24 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Sun, 6 Apr 2025 15:52:04 +0800 Subject: [PATCH 05/14] fix qkvo forward when every 4 layer Signed-off-by: ZX-ModelCloud --- gptqmodel/models/definitions/llama4.py | 3 +-- tests/models/test_llama4.py | 29 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py index eb3f0dccf..76e7f64bd 100644 --- a/gptqmodel/models/definitions/llama4.py +++ b/gptqmodel/models/definitions/llama4.py @@ -27,8 +27,7 @@ class Llama4GPTQ(BaseGPTQModel): layer_type = "Llama4TextDecoderLayer" layer_modules = [ - ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], - ["self_attn.o_proj"], + ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"], ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj"], ["feed_forward.shared_expert.down_proj"], diff --git a/tests/models/test_llama4.py b/tests/models/test_llama4.py index e69de29bb..047b8d4a4 100644 --- 
a/tests/models/test_llama4.py +++ b/tests/models/test_llama4.py @@ -0,0 +1,29 @@ +# Copyright 2024-2025 ModelCloud.ai +# Copyright 2024-2025 qubitium@modelcloud.ai +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from model_test import ModelTest + + +class TestLlama4(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Llama-4-Scout-17B-16E-Instruct" # "meta-llama/Llama-4-Scout-17B-16E-Instruct" + NATIVE_ARC_CHALLENGE_ACC = 0.3567 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805 + QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36 + APPLY_CHAT_TEMPLATE = True + TRUST_REMOTE_CODE = True + + def test_llama4(self): + self.quant_lm_eval() From 15e8c04fafca1bf82d51531c5fdf46e881f1a16e Mon Sep 17 00:00:00 2001 From: Qubitium-ModelCloud Date: Sun, 6 Apr 2025 16:22:20 +0800 Subject: [PATCH 06/14] Update llama4.py --- gptqmodel/models/definitions/llama4.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py index 76e7f64bd..d248b5684 100644 --- a/gptqmodel/models/definitions/llama4.py +++ b/gptqmodel/models/definitions/llama4.py @@ -29,6 +29,5 @@ class Llama4GPTQ(BaseGPTQModel): layer_modules = [ ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"], - ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj"], - ["feed_forward.shared_expert.down_proj"], + ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"], ] From 10d330f54fee5a5d6e7f4d08bc8c5fbcd8d3045a Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sun, 6 Apr 2025 16:30:11 +0800 Subject: [PATCH 07/14] add support_batch_quantize --- gptqmodel/looper/module_looper.py | 6 ++++-- gptqmodel/models/base.py | 20 +++++++++++++++----- gptqmodel/models/definitions/llama4.py | 7 ++++++- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 04d075ea7..aca3190cd 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -40,6 +40,7 @@ class ModuleLooper(): def __init__(self, model: BaseGPTQModel, processors: List[LoopProcessor]): self.processors = processors self.gptq_model = model + self.support_batch_quantize = model.support_batch_quantize def cache_inputs(self, layers, auto_gc, calibration_data, calibration_enable_gpu_cache): layer_inputs = [] @@ -292,7 +293,8 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal mask = attention_masks[j] layer_attention_mask = mask if mask is None else move_to(mask, device=cur_layer_device) - additional_layer_inputs = {"attention_mask": layer_attention_mask} + additional_layer_inputs = {"attention_mask": layer_attention_mask} if self.support_batch_quantize else {} + layer_position_ids = ( None if not position_ids else move_to(position_ids[j], device=cur_layer_device) ) @@ -371,7 +373,7 @@ def loop(self, auto_gc=True, 
calibration_enable_gpu_cache=True, buffered_fwd=Fal mask = attention_masks[j] layer_attention_mask = mask if mask is None else move_to(mask, device=cur_layer_device) - additional_layer_inputs = {"attention_mask": layer_attention_mask} + additional_layer_inputs = {"attention_mask": layer_attention_mask} if self.support_batch_quantize else {} layer_position_ids = None if not position_ids else move_to(position_ids[j], device=cur_layer_device) if layer_position_ids is not None: additional_layer_inputs["position_ids"] = layer_position_ids diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 12aa23c13..9691f3c8e 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -28,6 +28,7 @@ from packaging import version from packaging.version import Version from tokenicer import Tokenicer +from torch import LongTensor from transformers import (AutoModelForCausalLM, AutoProcessor, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, modeling_utils) @@ -123,6 +124,8 @@ class BaseGPTQModel(nn.Module): server = None + support_batch_quantize = True + def __init__( self, model: PreTrainedModel, @@ -306,11 +309,18 @@ def _convert_tensor_to_list(tensor): new_calibration_dataset = concatenated_data - new_calibration_dataset_batched = [ - collate_data(new_calibration_dataset[start: start + batch_size], self.tokenizer.pad_token_id) - for start in range(0, len(new_calibration_dataset), batch_size) - ] - + if self.support_batch_quantize: + new_calibration_dataset_batched = [ + collate_data(new_calibration_dataset[start: start + batch_size], self.tokenizer.pad_token_id) + for start in range(0, len(new_calibration_dataset), batch_size) + ] + else: + new_calibration_dataset_batched = [ + { + "input_ids": torch.cat([LongTensor(block["input_ids"]) for block in new_calibration_dataset[start: start + batch_size]], dim=0).long(), + } + for start in range(0, len(new_calibration_dataset), batch_size) + ] return new_calibration_dataset_batched diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py index d248b5684..3e76f5acb 100644 --- a/gptqmodel/models/definitions/llama4.py +++ b/gptqmodel/models/definitions/llama4.py @@ -14,10 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..base import BaseGPTQModel from transformers import AutoModelForImageTextToText +from ..base import BaseGPTQModel + + class Llama4GPTQ(BaseGPTQModel): + # some bug in the attention_mask of transformers.modeling_llama4, + # so batch quantization for Llama4 is temporarily not supported. 
+ support_batch_quantize = False loader = AutoModelForImageTextToText base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"] From 4b18850a3d145e0bf49237d629b62212f0fae14e Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sun, 6 Apr 2025 16:39:55 +0800 Subject: [PATCH 08/14] add warning --- gptqmodel/models/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 9691f3c8e..0f0c73003 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -316,10 +316,8 @@ def _convert_tensor_to_list(tensor): ] else: new_calibration_dataset_batched = [ - { - "input_ids": torch.cat([LongTensor(block["input_ids"]) for block in new_calibration_dataset[start: start + batch_size]], dim=0).long(), - } - for start in range(0, len(new_calibration_dataset), batch_size) + {"input_ids": LongTensor(block["input_ids"]).unsqueeze(0)} + for block in new_calibration_dataset ] return new_calibration_dataset_batched @@ -358,6 +356,10 @@ def quantize( "FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQMOdel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ." ) + if self.support_batch_quantize is False: + batch_size = 1 + log.warn("Batch quantization is not supported for this model. Setting batch_size to 1.") + # Validate quant linear before quantization starts _ = select_quant_linear( bits=self.quantize_config.bits, From e1ef1f7cc94f27e1dc41da2f0b03aec36bd35ccf Mon Sep 17 00:00:00 2001 From: Qubitium-ModelCloud Date: Sun, 6 Apr 2025 16:40:43 +0800 Subject: [PATCH 09/14] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a1d1e317a..1d8a959e7 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@

## Latest News -* 04/5/2025 2.3.0-dev: New experimental `multi-gpu` quantization support. Reduced vram usage. +* 04/6/2025 2.3.0-dev: Prelim `Llama 4` model (text-only) support. New experimental `multi-gpu` quantization support. Reduced vram usage. * 04/2/2025 [2.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v2.2.0): New `Qwen 2.5 VL` model support. New `samples` log column during quantization to track module activation in MoE models. `Loss` log column now color-coded to highlight modules that are friendly/resistant to quantization. Progress (per-step) stats during quantization now streamed to log file. Auto `bfloat16` dtype loading for models based on model config. Fix kernel compile for Pytorch/ROCm. Slightly faster quantization and auto-resolve some low-level oom issues for smaller vram gpus. * 03/12/2025 [2.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v2.1.0): ✨ New `QQQ` quantization method and inference support! New Google `Gemma 3` zero-day model support. From fd06d6a3fcf9b0f8e4db1ea487c07526cdff684e Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sun, 6 Apr 2025 17:07:09 +0800 Subject: [PATCH 10/14] fix data --- gptqmodel/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 0f0c73003..1890b4e41 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -316,7 +316,7 @@ def _convert_tensor_to_list(tensor): ] else: new_calibration_dataset_batched = [ - {"input_ids": LongTensor(block["input_ids"]).unsqueeze(0)} + {"input_ids": torch.cat(LongTensor(block["input_ids"]), dim=0).long(),} for block in new_calibration_dataset ] From b650d16463777ef06f1836c99008da8082aff660 Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sun, 6 Apr 2025 17:14:11 +0800 Subject: [PATCH 11/14] fix input_ids --- gptqmodel/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 1890b4e41..59388384b 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -316,7 +316,7 @@ def _convert_tensor_to_list(tensor): ] else: new_calibration_dataset_batched = [ - {"input_ids": torch.cat(LongTensor(block["input_ids"]), dim=0).long(),} + {"input_ids": torch.tensor(block["input_ids"], dtype=torch.long)} for block in new_calibration_dataset ] From 250060b5bdeac74e72a2528dd1d358ccac534e73 Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sun, 6 Apr 2025 20:34:58 +0800 Subject: [PATCH 12/14] update llama4 modules --- gptqmodel/models/definitions/llama4.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py index 3e76f5acb..483b4d9d0 100644 --- a/gptqmodel/models/definitions/llama4.py +++ b/gptqmodel/models/definitions/llama4.py @@ -34,5 +34,7 @@ class Llama4GPTQ(BaseGPTQModel): layer_modules = [ ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"], + # feed_forward.router contains feed_forward.experts.gate_up_proj and feed_forward.experts.down_proj + ["feed_forward.router"], ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"], ] From 35266f68df6392885cbdac7d90886e118f3ab945 Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sun, 6 Apr 2025 20:50:15 +0800 Subject: [PATCH 13/14] cleanup --- gptqmodel/models/definitions/llama4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gptqmodel/models/definitions/llama4.py 
b/gptqmodel/models/definitions/llama4.py index 483b4d9d0..0f78486c9 100644 --- a/gptqmodel/models/definitions/llama4.py +++ b/gptqmodel/models/definitions/llama4.py @@ -34,7 +34,6 @@ class Llama4GPTQ(BaseGPTQModel): layer_modules = [ ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"], - # feed_forward.router contains feed_forward.experts.gate_up_proj and feed_forward.experts.down_proj ["feed_forward.router"], ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"], ] From 44fc282656ac737f2cf15590e4906d9e9728fc17 Mon Sep 17 00:00:00 2001 From: LRL-ModelCloud Date: Sun, 6 Apr 2025 21:34:20 +0800 Subject: [PATCH 14/14] Revert "update llama4 modules" This reverts commit 250060b5bdeac74e72a2528dd1d358ccac534e73. --- gptqmodel/models/definitions/llama4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py index 0f78486c9..3e76f5acb 100644 --- a/gptqmodel/models/definitions/llama4.py +++ b/gptqmodel/models/definitions/llama4.py @@ -34,6 +34,5 @@ class Llama4GPTQ(BaseGPTQModel): layer_modules = [ ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"], - ["feed_forward.router"], ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"], ]
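---

Usage sketch (not part of the patch series): with these commits applied, Llama 4 text-only quantization should run through the normal GPTQModel flow. The example below is a minimal sketch that assumes the GPTQModel.load()/quantize()/save() API and the C4 calibration recipe from the project README; the bits/group_size settings, calibration slice, and output path are illustrative placeholders. The model id comes from tests/models/test_llama4.py.

    # Minimal sketch, assuming the GPTQModel.load()/quantize()/save() flow from the README.
    # bits/group_size, the calibration slice, and quant_path are placeholder choices.
    from datasets import load_dataset
    from gptqmodel import GPTQModel, QuantizeConfig

    model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
    quant_path = "Llama-4-Scout-17B-16E-Instruct-gptq-4bit"

    # Text-only calibration data; Llama4GPTQ is registered under the "llama4" model_type,
    # so GPTQModel resolves the AutoModelForImageTextToText loader added in this series.
    calibration = load_dataset(
        "allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz", split="train"
    ).select(range(256))["text"]

    model = GPTQModel.load(model_id, QuantizeConfig(bits=4, group_size=128))

    # Llama4GPTQ sets support_batch_quantize = False, so quantize() forces batch_size to 1
    # and logs a warning; larger batch_size values are ignored for this model.
    model.quantize(calibration, batch_size=1)
    model.save(quant_path)

Because attention masks are omitted from layer inputs when support_batch_quantize is False, each calibration example is forwarded individually. This makes Llama 4 quantization slower per sample than batched models, but it sidesteps the transformers.modeling_llama4 attention_mask issue noted in the Llama4GPTQ definition.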