From f6170b01e1fb936ec423f6aa0d0b15095648f867 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 6 Apr 2025 01:13:12 +0000
Subject: [PATCH 01/14] update transformers for llama 4
Signed-off-by: Qubitium
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index f5e6e7d37..ef92b2c05 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,7 @@ datasets>=3.2.0
numpy>=1.26.4
torch>=2.2.0
safetensors>=0.5.2
-transformers>=4.49.0
+transformers>=4.51.0
threadpoolctl>=3.6.0
packaging>=24.2
device-smi==0.4.1
From bb89c7607065fb908946ac54d26d29503d815166 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Sun, 6 Apr 2025 12:43:57 +0800
Subject: [PATCH 02/14] add Llama4GPTQ
---
gptqmodel/models/auto.py | 2 ++
gptqmodel/models/definitions/__init__.py | 1 +
gptqmodel/models/definitions/llama4.py | 37 ++++++++++++++++++++++++
tests/models/test_llama4.py | 0
4 files changed, 40 insertions(+)
create mode 100644 gptqmodel/models/definitions/llama4.py
create mode 100644 tests/models/test_llama4.py
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index c2454807c..2c76de52d 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -89,6 +89,7 @@
from .definitions.internlm import InternLMGPTQ # noqa: E402
from .definitions.internlm2 import InternLM2GPTQ # noqa: E402
from .definitions.llama import LlamaGPTQ # noqa: E402
+from .definitions.llama4 import Llama4GPTQ # noqa: E402
from .definitions.longllama import LongLlamaGPTQ # noqa: E402
from .definitions.minicpm import MiniCPMGPTQ # noqa: E402
from .definitions.minicpm3 import MiniCPM3GPTQ # noqa: E402
@@ -128,6 +129,7 @@
"gptj": GPTJGPTQ,
"gpt2": GPT2GPTQ,
"llama": LlamaGPTQ,
+ "llama4": Llama4GPTQ,
"opt": OPTGPTQ,
"moss": MOSSGPTQ,
"chatglm": ChatGLM,
diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py
index c2dac053c..a23e87e93 100644
--- a/gptqmodel/models/definitions/__init__.py
+++ b/gptqmodel/models/definitions/__init__.py
@@ -41,6 +41,7 @@
from .internlm import InternLMGPTQ
from .internlm2 import InternLM2GPTQ
from .llama import LlamaGPTQ
+from .llama4 import Llama4GPTQ
from .longllama import LongLlamaGPTQ
from .minicpm3 import MiniCPM3GPTQ
from .mistral import MistralGPTQ
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
new file mode 100644
index 000000000..faa4d3d8a
--- /dev/null
+++ b/gptqmodel/models/definitions/llama4.py
@@ -0,0 +1,37 @@
+# Copyright 2024-2025 ModelCloud.ai
+# Copyright 2024-2025 qubitium@modelcloud.ai
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseGPTQModel
+
+
+class Llama4GPTQ(BaseGPTQModel):
+ base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"]
+ pre_lm_head_norm_module = "language_model.model.norm"
+
+ layers_node = "language_model.model.layers"
+ layer_type = "Llama4TextDecoderLayer"
+
+ layer_modules = [
+ ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
+ ["self_attn.o_proj"],
+
+
+ ["feed_forward.experts.gate_up_proj"],
+ ["feed_forward.experts.down_proj"],
+
+ ["feed_forward.shared_expert.down_proj"],
+ ["feed_forward.shared_expert.down_proj"],
+ ]
diff --git a/tests/models/test_llama4.py b/tests/models/test_llama4.py
new file mode 100644
index 000000000..e69de29bb
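For orientation, here is a minimal usage sketch showing how the new Llama4GPTQ definition would be exercised once the "llama4" key is registered in auto.py. It is not part of the patch series: it assumes the public GPTQModel.load / quantize / save API, and the output path and calibration strings are illustrative placeholders.

    # Sketch only (not from the patches): quantize a Llama 4 checkpoint through the
    # new "llama4" registry entry. Paths and calibration data are placeholders.
    from gptqmodel import GPTQModel, QuantizeConfig

    quant_config = QuantizeConfig(bits=4, group_size=128)

    model = GPTQModel.load("meta-llama/Llama-4-Scout-17B-16E-Instruct", quant_config)

    calibration = [
        "GPTQ quantizes one decoder layer at a time against a small calibration set.",
        "Llama 4 pairs a shared expert with routed experts in each feed-forward block.",
    ]

    model.quantize(calibration)
    model.save("Llama-4-Scout-17B-16E-Instruct-GPTQ-4bit")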
From 0ebfb175f488125f871ef4f26157056c0c96a719 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Sun, 6 Apr 2025 14:45:59 +0800
Subject: [PATCH 03/14] use loader AutoModelForImageTextToText
---
gptqmodel/models/definitions/llama4.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
index faa4d3d8a..07a596337 100644
--- a/gptqmodel/models/definitions/llama4.py
+++ b/gptqmodel/models/definitions/llama4.py
@@ -15,9 +15,11 @@
# limitations under the License.
from ..base import BaseGPTQModel
-
+from transformers import AutoModelForImageTextToText
class Llama4GPTQ(BaseGPTQModel):
+ loader = AutoModelForImageTextToText
+
base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"]
pre_lm_head_norm_module = "language_model.model.norm"
From 006d7a6172fed9c881edfdd37b1cea88aed662c9 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Sun, 6 Apr 2025 15:05:49 +0800
Subject: [PATCH 04/14] cleanup
---
gptqmodel/models/definitions/llama4.py | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
index 07a596337..eb3f0dccf 100644
--- a/gptqmodel/models/definitions/llama4.py
+++ b/gptqmodel/models/definitions/llama4.py
@@ -30,10 +30,6 @@ class Llama4GPTQ(BaseGPTQModel):
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
["self_attn.o_proj"],
-
- ["feed_forward.experts.gate_up_proj"],
- ["feed_forward.experts.down_proj"],
-
- ["feed_forward.shared_expert.down_proj"],
+ ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj"],
["feed_forward.shared_expert.down_proj"],
]
From 26f074aacd5b2aadfe33a68e9c4fd6117b936d24 Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Sun, 6 Apr 2025 15:52:04 +0800
Subject: [PATCH 05/14] fix qkvo forward on every 4th layer
Signed-off-by: ZX-ModelCloud
---
gptqmodel/models/definitions/llama4.py | 3 +--
tests/models/test_llama4.py | 29 ++++++++++++++++++++++++++
2 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
index eb3f0dccf..76e7f64bd 100644
--- a/gptqmodel/models/definitions/llama4.py
+++ b/gptqmodel/models/definitions/llama4.py
@@ -27,8 +27,7 @@ class Llama4GPTQ(BaseGPTQModel):
layer_type = "Llama4TextDecoderLayer"
layer_modules = [
- ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
- ["self_attn.o_proj"],
+ ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],
["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj"],
["feed_forward.shared_expert.down_proj"],
diff --git a/tests/models/test_llama4.py b/tests/models/test_llama4.py
index e69de29bb..047b8d4a4 100644
--- a/tests/models/test_llama4.py
+++ b/tests/models/test_llama4.py
@@ -0,0 +1,29 @@
+# Copyright 2024-2025 ModelCloud.ai
+# Copyright 2024-2025 qubitium@modelcloud.ai
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from model_test import ModelTest
+
+
+class TestLlama4(ModelTest):
+ NATIVE_MODEL_ID = "/monster/data/model/Llama-4-Scout-17B-16E-Instruct" # "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+ NATIVE_ARC_CHALLENGE_ACC = 0.3567
+ NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805
+ QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36
+ APPLY_CHAT_TEMPLATE = True
+ TRUST_REMOTE_CODE = True
+
+ def test_llama4(self):
+ self.quant_lm_eval()
From 15e8c04fafca1bf82d51531c5fdf46e881f1a16e Mon Sep 17 00:00:00 2001
From: Qubitium-ModelCloud
Date: Sun, 6 Apr 2025 16:22:20 +0800
Subject: [PATCH 06/14] Update llama4.py
---
gptqmodel/models/definitions/llama4.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
index 76e7f64bd..d248b5684 100644
--- a/gptqmodel/models/definitions/llama4.py
+++ b/gptqmodel/models/definitions/llama4.py
@@ -29,6 +29,5 @@ class Llama4GPTQ(BaseGPTQModel):
layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],
- ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj"],
- ["feed_forward.shared_expert.down_proj"],
+ ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"],
]
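Patches 04 through 06 progressively reshape layer_modules: each inner list names the projections handled in one quantization sub-step inside a Llama4TextDecoderLayer, so merging names into a single inner list groups them into the same pass. The sketch below illustrates that list-of-lists convention in isolation; it is an assumption-labelled illustration, not code from this repository.

    # Illustration only: resolving a layer_modules mapping against a decoder layer.
    # Each inner list is one group of Linear submodules quantized together.
    from typing import Dict, List

    import torch.nn as nn

    def resolve_groups(layer: nn.Module, layer_modules: List[List[str]]) -> List[Dict[str, nn.Module]]:
        # get_submodule resolves dotted names such as "self_attn.k_proj".
        return [{name: layer.get_submodule(name) for name in group} for group in layer_modules]

    layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],
        ["feed_forward.shared_expert.gate_proj",
         "feed_forward.shared_expert.up_proj",
         "feed_forward.shared_expert.down_proj"],
    ]
    # For a real Llama4TextDecoderLayer instance, resolve_groups(layer, layer_modules)
    # yields two name-to-module dicts that the quantization loop processes in order.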
From 10d330f54fee5a5d6e7f4d08bc8c5fbcd8d3045a Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Sun, 6 Apr 2025 16:30:11 +0800
Subject: [PATCH 07/14] add support_batch_quantize
---
gptqmodel/looper/module_looper.py | 6 ++++--
gptqmodel/models/base.py | 20 +++++++++++++++-----
gptqmodel/models/definitions/llama4.py | 7 ++++++-
3 files changed, 25 insertions(+), 8 deletions(-)
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 04d075ea7..aca3190cd 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -40,6 +40,7 @@ class ModuleLooper():
def __init__(self, model: BaseGPTQModel, processors: List[LoopProcessor]):
self.processors = processors
self.gptq_model = model
+ self.support_batch_quantize = model.support_batch_quantize
def cache_inputs(self, layers, auto_gc, calibration_data, calibration_enable_gpu_cache):
layer_inputs = []
@@ -292,7 +293,8 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal
mask = attention_masks[j]
layer_attention_mask = mask if mask is None else move_to(mask, device=cur_layer_device)
- additional_layer_inputs = {"attention_mask": layer_attention_mask}
+ additional_layer_inputs = {"attention_mask": layer_attention_mask} if self.support_batch_quantize else {}
+
layer_position_ids = (
None if not position_ids else move_to(position_ids[j], device=cur_layer_device)
)
@@ -371,7 +373,7 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal
mask = attention_masks[j]
layer_attention_mask = mask if mask is None else move_to(mask, device=cur_layer_device)
- additional_layer_inputs = {"attention_mask": layer_attention_mask}
+ additional_layer_inputs = {"attention_mask": layer_attention_mask} if self.support_batch_quantize else {}
layer_position_ids = None if not position_ids else move_to(position_ids[j], device=cur_layer_device)
if layer_position_ids is not None:
additional_layer_inputs["position_ids"] = layer_position_ids
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 12aa23c13..9691f3c8e 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -28,6 +28,7 @@
from packaging import version
from packaging.version import Version
from tokenicer import Tokenicer
+from torch import LongTensor
from transformers import (AutoModelForCausalLM, AutoProcessor, PreTrainedModel,
PreTrainedTokenizerBase, ProcessorMixin, modeling_utils)
@@ -123,6 +124,8 @@ class BaseGPTQModel(nn.Module):
server = None
+ support_batch_quantize = True
+
def __init__(
self,
model: PreTrainedModel,
@@ -306,11 +309,18 @@ def _convert_tensor_to_list(tensor):
new_calibration_dataset = concatenated_data
- new_calibration_dataset_batched = [
- collate_data(new_calibration_dataset[start: start + batch_size], self.tokenizer.pad_token_id)
- for start in range(0, len(new_calibration_dataset), batch_size)
- ]
-
+ if self.support_batch_quantize:
+ new_calibration_dataset_batched = [
+ collate_data(new_calibration_dataset[start: start + batch_size], self.tokenizer.pad_token_id)
+ for start in range(0, len(new_calibration_dataset), batch_size)
+ ]
+ else:
+ new_calibration_dataset_batched = [
+ {
+ "input_ids": torch.cat([LongTensor(block["input_ids"]) for block in new_calibration_dataset[start: start + batch_size]], dim=0).long(),
+ }
+ for start in range(0, len(new_calibration_dataset), batch_size)
+ ]
return new_calibration_dataset_batched
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
index d248b5684..3e76f5acb 100644
--- a/gptqmodel/models/definitions/llama4.py
+++ b/gptqmodel/models/definitions/llama4.py
@@ -14,10 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ..base import BaseGPTQModel
from transformers import AutoModelForImageTextToText
+from ..base import BaseGPTQModel
+
+
class Llama4GPTQ(BaseGPTQModel):
+ # some bug in the attention_mask of transformers.modeling_llama4,
+ # so batch quantization for Llama4 is temporarily not supported.
+ support_batch_quantize = False
loader = AutoModelForImageTextToText
base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"]
From 4b18850a3d145e0bf49237d629b62212f0fae14e Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Sun, 6 Apr 2025 16:39:55 +0800
Subject: [PATCH 08/14] add warning
---
gptqmodel/models/base.py | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 9691f3c8e..0f0c73003 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -316,10 +316,8 @@ def _convert_tensor_to_list(tensor):
]
else:
new_calibration_dataset_batched = [
- {
- "input_ids": torch.cat([LongTensor(block["input_ids"]) for block in new_calibration_dataset[start: start + batch_size]], dim=0).long(),
- }
- for start in range(0, len(new_calibration_dataset), batch_size)
+ {"input_ids": LongTensor(block["input_ids"]).unsqueeze(0)}
+ for block in new_calibration_dataset
]
return new_calibration_dataset_batched
@@ -358,6 +356,10 @@ def quantize(
"FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQMOdel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ."
)
+ if self.support_batch_quantize is False:
+ batch_size = 1
+ log.warn("Batch quantization is not supported for this model. Setting batch_size to 1.")
+
# Validate quant linear before quantization starts
_ = select_quant_linear(
bits=self.quantize_config.bits,
From e1ef1f7cc94f27e1dc41da2f0b03aec36bd35ccf Mon Sep 17 00:00:00 2001
From: Qubitium-ModelCloud
Date: Sun, 6 Apr 2025 16:40:43 +0800
Subject: [PATCH 09/14] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index a1d1e317a..1d8a959e7 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@
## Latest News
-* 04/5/2025 2.3.0-dev: New experimental `multi-gpu` quantization support. Reduced vram usage.
+* 04/6/2025 2.3.0-dev: Prelim `Llama 4` model (text-only) support. New experimental `multi-gpu` quantization support. Reduced vram usage.
* 04/2/2025 [2.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v2.2.0): New `Qwen 2.5 VL` model support. New `samples` log column during quantization to track module activation in MoE models. `Loss` log column now color-coded to highlight modules that are friendly/resistant to quantization. Progress (per-step) stats during quantization now streamed to log file. Auto `bfloat16` dtype loading for models based on model config. Fix kernel compile for Pytorch/ROCm. Slightly faster quantization and auto-resolve some low-level oom issues for smaller vram gpus.
* 03/12/2025 [2.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v2.1.0): ✨ New `QQQ` quantization method and inference support!
New Google `Gemma 3` zero-day model support.
From fd06d6a3fcf9b0f8e4db1ea487c07526cdff684e Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Sun, 6 Apr 2025 17:07:09 +0800
Subject: [PATCH 10/14] fix data
---
gptqmodel/models/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 0f0c73003..1890b4e41 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -316,7 +316,7 @@ def _convert_tensor_to_list(tensor):
]
else:
new_calibration_dataset_batched = [
- {"input_ids": LongTensor(block["input_ids"]).unsqueeze(0)}
+ {"input_ids": torch.cat(LongTensor(block["input_ids"]), dim=0).long(),}
for block in new_calibration_dataset
]
From b650d16463777ef06f1836c99008da8082aff660 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Sun, 6 Apr 2025 17:14:11 +0800
Subject: [PATCH 11/14] fix input_ids
---
gptqmodel/models/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 1890b4e41..59388384b 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -316,7 +316,7 @@ def _convert_tensor_to_list(tensor):
]
else:
new_calibration_dataset_batched = [
- {"input_ids": torch.cat(LongTensor(block["input_ids"]), dim=0).long(),}
+ {"input_ids": torch.tensor(block["input_ids"], dtype=torch.long)}
for block in new_calibration_dataset
]
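Patches 07 through 11 iterate on how calibration data is shaped when support_batch_quantize is False: PATCH 07 concatenates each batch slice, PATCH 08 switches to one tensor per sample, PATCH 10 reworks that, and PATCH 11 settles on a plain long tensor per sample. The toy sketch below contrasts the padded, collated layout used on the batched path with the per-sample layout; the values and pad id are illustrative, not taken from the repository.

    # Toy contrast of the two calibration layouts (values and pad id are illustrative).
    import torch

    samples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]

    # support_batch_quantize=True: pad and collate into a single [batch, seq] tensor,
    # which is why an attention_mask must accompany the layer forward.
    batched = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(s["input_ids"], dtype=torch.long) for s in samples],
        batch_first=True,
        padding_value=0,
    )
    print(batched.shape)  # torch.Size([2, 3])

    # support_batch_quantize=False (the per-sample shape PATCH 11 lands on): keep one
    # un-padded tensor per sample, so no padding and no attention_mask is required.
    per_sample = [{"input_ids": torch.tensor(s["input_ids"], dtype=torch.long)} for s in samples]
    print(per_sample[1]["input_ids"].shape)  # torch.Size([2])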
From 250060b5bdeac74e72a2528dd1d358ccac534e73 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Sun, 6 Apr 2025 20:34:58 +0800
Subject: [PATCH 12/14] update llama4 modules
---
gptqmodel/models/definitions/llama4.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
index 3e76f5acb..483b4d9d0 100644
--- a/gptqmodel/models/definitions/llama4.py
+++ b/gptqmodel/models/definitions/llama4.py
@@ -34,5 +34,7 @@ class Llama4GPTQ(BaseGPTQModel):
layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],
+ # feed_forward.router contains feed_forward.experts.gate_up_proj and feed_forward.experts.down_proj
+ ["feed_forward.router"],
["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"],
]
From 35266f68df6392885cbdac7d90886e118f3ab945 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Sun, 6 Apr 2025 20:50:15 +0800
Subject: [PATCH 13/14] cleanup
---
gptqmodel/models/definitions/llama4.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
index 483b4d9d0..0f78486c9 100644
--- a/gptqmodel/models/definitions/llama4.py
+++ b/gptqmodel/models/definitions/llama4.py
@@ -34,7 +34,6 @@ class Llama4GPTQ(BaseGPTQModel):
layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],
- # feed_forward.router contains feed_forward.experts.gate_up_proj and feed_forward.experts.down_proj
["feed_forward.router"],
["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"],
]
From 44fc282656ac737f2cf15590e4906d9e9728fc17 Mon Sep 17 00:00:00 2001
From: LRL-ModelCloud
Date: Sun, 6 Apr 2025 21:34:20 +0800
Subject: [PATCH 14/14] Revert "update llama4 modules"
This reverts commit 250060b5bdeac74e72a2528dd1d358ccac534e73.
---
gptqmodel/models/definitions/llama4.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
index 0f78486c9..3e76f5acb 100644
--- a/gptqmodel/models/definitions/llama4.py
+++ b/gptqmodel/models/definitions/llama4.py
@@ -34,6 +34,5 @@ class Llama4GPTQ(BaseGPTQModel):
layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],
- ["feed_forward.router"],
["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"],
]
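For reference, applying all 14 patches (including the PATCH 14 revert of PATCH 12) leaves gptqmodel/models/definitions/llama4.py in the net state below. This recap is reconstructed from the diffs above; the Apache license header is omitted for brevity.

    from transformers import AutoModelForImageTextToText

    from ..base import BaseGPTQModel


    class Llama4GPTQ(BaseGPTQModel):
        # some bug in the attention_mask of transformers.modeling_llama4,
        # so batch quantization for Llama4 is temporarily not supported.
        support_batch_quantize = False
        loader = AutoModelForImageTextToText

        base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"]
        pre_lm_head_norm_module = "language_model.model.norm"

        layers_node = "language_model.model.layers"
        layer_type = "Llama4TextDecoderLayer"

        layer_modules = [
            ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],
            ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"],
        ]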