Skip to content

Llama 4 Support #1508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gptqmodel/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
from .definitions.internlm import InternLMGPTQ # noqa: E402
from .definitions.internlm2 import InternLM2GPTQ # noqa: E402
from .definitions.llama import LlamaGPTQ # noqa: E402
from .definitions.llama4 import Llama4GPTQ # noqa: E402
from .definitions.longllama import LongLlamaGPTQ # noqa: E402
from .definitions.mimo import MimoGPTQ # noqa: E402
from .definitions.minicpm import MiniCPMGPTQ # noqa: E402
Expand Down Expand Up @@ -144,6 +145,7 @@
"gptj": GPTJGPTQ,
"gpt2": GPT2GPTQ,
"llama": LlamaGPTQ,
"llama4": Llama4GPTQ,
"opt": OPTGPTQ,
"moss": MOSSGPTQ,
"chatglm": ChatGLM,
Expand Down
7 changes: 7 additions & 0 deletions gptqmodel/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import torch.nn as nn
from packaging import version
from tokenicer import Tokenicer
from torch import LongTensor
from transformers import (AutoModelForCausalLM, AutoProcessor, PreTrainedModel,
PreTrainedTokenizerBase, ProcessorMixin, modeling_utils)

Expand Down Expand Up @@ -124,6 +125,8 @@ class BaseGPTQModel(nn.Module):

server = None

support_batch_quantize = True

def __init__(
self,
model: PreTrainedModel,
Expand Down Expand Up @@ -371,6 +374,10 @@ def quantize(
"FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQMOdel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ."
)

if self.support_batch_quantize is False:
batch_size = 1
log.warn("Batch quantization is not supported for this model. Setting batch_size to 1.")

# Validate quant linear before quantization starts
_ = select_quant_linear(
bits=self.quantize_config.bits,
Expand Down
1 change: 1 addition & 0 deletions gptqmodel/models/definitions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from .internlm import InternLMGPTQ
from .internlm2 import InternLM2GPTQ
from .llama import LlamaGPTQ
from .llama4 import Llama4GPTQ
from .longllama import LongLlamaGPTQ
from .mimo import MimoGPTQ
from .minicpm3 import MiniCPM3GPTQ
Expand Down
38 changes: 38 additions & 0 deletions gptqmodel/models/definitions/llama4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024-2025 ModelCloud.ai
# Copyright 2024-2025 qubitium@modelcloud.ai
# Contact: qubitium@modelcloud.ai, x.com/qubitium
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoModelForImageTextToText

from ..base import BaseGPTQModel


class Llama4GPTQ(BaseGPTQModel):
# some bug in the attention_mask of transformers.modeling_llama4,
# so batch quantization for Llama4 is temporarily not supported.
support_batch_quantize = False
loader = AutoModelForImageTextToText

base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"]
pre_lm_head_norm_module = "language_model.model.norm"

layers_node = "language_model.model.layers"
layer_type = "Llama4TextDecoderLayer"

layer_modules = [
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],

["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"],
]
29 changes: 29 additions & 0 deletions tests/models/test_llama4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2024-2025 ModelCloud.ai
# Copyright 2024-2025 qubitium@modelcloud.ai
# Contact: qubitium@modelcloud.ai, x.com/qubitium
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from model_test import ModelTest


class TestLlama4(ModelTest):
NATIVE_MODEL_ID = "/monster/data/model/Llama-4-Scout-17B-16E-Instruct" # "meta-llama/Llama-4-Scout-17B-16E-Instruct"
NATIVE_ARC_CHALLENGE_ACC = 0.3567
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805
QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36
APPLY_CHAT_TEMPLATE = True
TRUST_REMOTE_CODE = True

def test_llama4(self):
self.quant_lm_eval()