diff --git a/README.md b/README.md
index c229b1a30..2a59c2d60 100644
--- a/README.md
+++ b/README.md
@@ -174,18 +174,17 @@ Native support support some of the most popular multi-modal models:
 ## Model Support
 | Model | | | | | | | | | |
 |-------------------|---|-------------------|---|----------------|---|----------------|---|------------|---|
-| Baichuan | ✅ | EXAONE 3.0 | ✅ | InternLM 1/2.5 | ✅ | OPT | ✅ | StableLM | ✅ |
-| Bloom | ✅ | Falcon (H1) | ✅ | Llama 1-3.3 | ✅ | OLMo2 | ✅ | StarCoder2 | ✅ |
-| ChatGLM | ✅ | Gemma 1/2/3 | ✅ | Llama 3.2 VL | ✅ | Ovis 1.6/2 | ✅ | TeleChat2 | ✅ |
-| CodeGen | ✅ | GPTBigCod | ✅ | LongLLaMA | ✅ | Phi 1-4 | ✅ | Yi | ✅ |
-| Cohere 1-2 | ✅ | GPTQ-Neo/GPT-NeoX | ✅ | Instella | ✅ | Nemotron Ultra | ✅ | Seed-OSS | ✅ |
-| DBRX Converted | ✅ | GPT-2 | ✅ | MiniCPM3 | ✅ | PanGu-α | ✅ | XVERSE | ✅ |
-| Deci | ✅ | GPT-J | ✅ | Mistral | ✅ | Qwen 1/2/3 | ✅ | | |
-| DeepSeek-V2/V3/R1 | ✅ | GPT-OSS | ✅ | Mixtral | ✅ | Qwen 2/3 MoE | ✅ | | |
-| DeepSeek-V2-Lite | ✅ | Granite | ✅ | MobileLLM | ✅ | Qwen 2/2.5 VL | ✅ | | |
-| Dream | ✅ | GRIN-MoE | ✅ | MOSS | ✅ | Qwen 2.5 Omni | ✅ | | |
-| ERNIE 4.5 | ✅ | Hymba | ✅ | MPT | ✅ | RefinedWeb | ✅ | | |
-
+| Baichuan | ✅ | EXAONE 3.0 | ✅ | InternLM 1/2.5 | ✅ | MPT | ✅ | RefinedWeb | ✅ |
+| Bloom | ✅ | Falcon (H1) | ✅ | Llama 1-3.3 | ✅ | OPT | ✅ | StableLM | ✅ |
+| ChatGLM | ✅ | Gemma 1/2/3 | ✅ | Llama 3.2 VL | ✅ | OLMo2 | ✅ | StarCoder2 | ✅ |
+| CodeGen | ✅ | GPTBigCod | ✅ | Llama 4 | ✅ | Ovis 1.6/2 | ✅ | TeleChat2 | ✅ |
+| Cohere 1-2 | ✅ | GPTQ-Neo/GPT-NeoX | ✅ | LongLLaMA | ✅ | Phi 1-4 | ✅ | Yi | ✅ |
+| DBRX Converted | ✅ | GPT-2 | ✅ | Instella | ✅ | Nemotron Ultra | ✅ | Seed-OSS | ✅ |
+| Deci | ✅ | GPT-J | ✅ | MiniCPM3 | ✅ | PanGu-α | ✅ | XVERSE | ✅ |
+| DeepSeek-V2/V3/R1 | ✅ | GPT-OSS | ✅ | Mistral | ✅ | Qwen 1/2/3 | ✅ | | |
+| DeepSeek-V2-Lite | ✅ | Granite | ✅ | Mixtral | ✅ | Qwen 2/3 MoE | ✅ | | |
+| Dream | ✅ | GRIN-MoE | ✅ | MobileLLM | ✅ | Qwen 2/2.5 VL | ✅ | | |
+| ERNIE 4.5 | ✅ | Hymba | ✅ | MOSS | ✅ | Qwen 2.5 Omni | ✅ | | |
 
 ## Platform and HW Support
 
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 26ca9f99d..cc934b8ee 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -196,7 +196,10 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal
 
         # dynamic expert layer index for model defs
         if self.gptq_model.dynamic_expert_index is not None:
-            num_experts = getattr(self.gptq_model.model.config, self.gptq_model.dynamic_expert_index)
+            if hasattr(self.gptq_model.model.config, "text_config"):
+                num_experts = getattr(self.gptq_model.model.config.text_config, self.gptq_model.dynamic_expert_index)
+            else:
+                num_experts = getattr(self.gptq_model.model.config, self.gptq_model.dynamic_expert_index)
             layer_modules = get_moe_layer_modules(layer_modules=self.gptq_model.layer_modules,
                                                   num_experts=num_experts)
 
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index 4bd2759e8..1713bd563 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -99,6 +99,7 @@
 from .definitions.internlm import InternLMGPTQ  # noqa: E402
 from .definitions.internlm2 import InternLM2GPTQ  # noqa: E402
 from .definitions.llama import LlamaGPTQ  # noqa: E402
+from .definitions.llama4 import Llama4GPTQ  # noqa: E402
 from .definitions.longllama import LongLlamaGPTQ  # noqa: E402
 from .definitions.mimo import MimoGPTQ  # noqa: E402
 from .definitions.minicpm import MiniCPMGPTQ  # noqa: E402
@@ -145,6 +146,7 @@
     "gptj": GPTJGPTQ,
     "gpt2": GPT2GPTQ,
     "llama": LlamaGPTQ,
+    "llama4": Llama4GPTQ,
     "opt": OPTGPTQ,
     "moss": MOSSGPTQ,
     "chatglm": ChatGLM,
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index da73540c0..9aaaffa18 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -26,6 +26,7 @@
 import torch._dynamo
 import torch.nn as nn
 from tokenicer import Tokenicer
+from torch import LongTensor
 from transformers import (AutoModelForCausalLM, AutoProcessor, PreTrainedModel, PreTrainedTokenizerBase,
                           ProcessorMixin, modeling_utils)
 
@@ -123,6 +124,8 @@ class BaseGPTQModel(nn.Module):
 
     server = None
 
+    support_batch_quantize = True
+
     def __init__(
         self,
         model: PreTrainedModel,
@@ -370,6 +373,10 @@ def quantize(
                 "FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQMOdel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ."
             )
 
+        if self.support_batch_quantize is False:
+            batch_size = 1
+            log.warn("Batch quantization is not supported for this model. Setting batch_size to 1.")
+
         # Validate quant linear before quantization starts
         _ = select_quant_linear(
             bits=self.quantize_config.bits,
diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py
index e96a8f5d3..95a4486a0 100644
--- a/gptqmodel/models/definitions/__init__.py
+++ b/gptqmodel/models/definitions/__init__.py
@@ -50,6 +50,7 @@
 from .internlm import InternLMGPTQ
 from .internlm2 import InternLM2GPTQ
 from .llama import LlamaGPTQ
+from .llama4 import Llama4GPTQ
 from .longllama import LongLlamaGPTQ
 from .mimo import MimoGPTQ
 from .minicpm3 import MiniCPM3GPTQ
diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py
new file mode 100644
index 000000000..178239e12
--- /dev/null
+++ b/gptqmodel/models/definitions/llama4.py
@@ -0,0 +1,172 @@
+# Copyright 2024-2025 ModelCloud.ai
+# Copyright 2024-2025 qubitium@modelcloud.ai
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import AutoModelForImageTextToText
+from .._const import EXPERT_INDEX_PLACEHOLDER
+from ..base import BaseGPTQModel
+
+class Llama4GPTQ(BaseGPTQModel):
+    # There is a bug in the attention_mask handling of transformers' modeling_llama4,
+    # so batch quantization for Llama 4 is temporarily not supported.
+    support_batch_quantize = False
+    loader = AutoModelForImageTextToText
+
+    base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"]
+    pre_lm_head_norm_module = "language_model.model.norm"
+
+    layers_node = "language_model.model.layers"
+    layer_type = "Llama4TextDecoderLayer"
+
+    dynamic_expert_index = "num_local_experts"
+
+    layer_modules = [
+        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],
+
+        [f"feed_forward.experts.{EXPERT_INDEX_PLACEHOLDER}.gate_proj", f"feed_forward.experts.{EXPERT_INDEX_PLACEHOLDER}.up_proj"],
+        [f"feed_forward.experts.{EXPERT_INDEX_PLACEHOLDER}.down_proj"],
+
+        ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"],
+    ]
+
+    def before_model_load(self, load_quantized_model=False):
+        if load_quantized_model:
+            import torch
+            import torch.nn as nn
+            import transformers.models.llama4.modeling_llama4 as llama4_modeling
+            from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
+
+            @use_kernel_forward_from_hub("Llama4TextMoe")
+            class SequentialLlama4TextMoe(torch.nn.Module):
+                def __init__(self, config):
+                    super().__init__()
+                    self.top_k = config.num_experts_per_tok
+                    self.hidden_dim = config.hidden_size
+                    # number of routed experts comes from the text config (e.g. 16 for Scout, 128 for Maverick)
+                    self.num_experts = config.num_local_experts
+                    self.experts = nn.ModuleList(
+                        [llama4_modeling.Llama4TextMLP(config) for _ in range(self.num_experts)]
+                    )
+                    self.router = llama4_modeling.Llama4Router(config)
+                    self.shared_expert = llama4_modeling.Llama4TextMLP(config)
+
+                def forward(self, hidden_states: torch.Tensor):
+                    hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+                    router_logits = self.router(hidden_states)
+                    if isinstance(router_logits, tuple):
+                        router_scores, router_logits = router_logits
+                        router_scores = router_scores.t()
+                    else:
+                        # transformers < 4.54.0 only returns router_logits
+                        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+
+                        router_scores = (
+                            torch.full_like(router_logits, float("-inf"))
+                            .scatter_(1, router_indices, router_top_value)
+                            .transpose(0, 1)
+                        )
+                        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+
+                    out = self.shared_expert(hidden_states)
+                    for i in range(self.num_experts):
+                        out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+
+                    return out, router_logits
+
+            llama4_modeling.Llama4TextMoe = SequentialLlama4TextMoe
+
+    def after_model_load(self, model, load_quantized_model=False):
+        if load_quantized_model:
+            return model
+
+        import torch
+        from concurrent.futures import ThreadPoolExecutor
+        from functools import partial
+        from transformers.modeling_utils import no_init_weights
+        from transformers.models.llama4.modeling_llama4 import Llama4TextMLP, Llama4TextMoe
+
+        # adapted/modified from https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
+        class SequentialLlama4TextExperts(torch.nn.ModuleList):
+            def __init__(self, config, original):
+                self.num_experts = original.gate_up_proj.shape[0]
+                with no_init_weights():
+                    super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
+                intermediate_size = original.down_proj.shape[1]
+
+                with torch.no_grad():
+                    # Batch process all expert parameters to avoid loops
+                    gate_up_batch = torch.stack([original.gate_up_proj[i] for i in range(self.num_experts)])
+                    down_batch = torch.stack([original.down_proj[i] for i in range(self.num_experts)])
+
+                    # Batch split and transpose
+                    gate_batch = gate_up_batch[:, :, :intermediate_size].transpose(-2, -1).contiguous()
+                    up_batch = gate_up_batch[:, :, intermediate_size:].transpose(-2, -1).contiguous()
+                    down_batch = down_batch.transpose(-2, -1).contiguous()
+
+                    # Batch assignment
+                    for i in range(self.num_experts):
+                        self[i].gate_proj.weight.data = gate_batch[i]
+                        self[i].up_proj.weight.data = up_batch[i]
+                        self[i].down_proj.weight.data = down_batch[i]
+
+        class SequentialLlama4TextMoe(torch.nn.Module):
+            def __init__(self, config, original):
+                super().__init__()
+                self.top_k = config.num_experts_per_tok
+                self.hidden_dim = config.hidden_size
+                self.num_experts = config.num_local_experts
+                self.experts = SequentialLlama4TextExperts(config, original.experts)
+                self.router = original.router
+                self.shared_expert = original.shared_expert
+
+            def forward(self, hidden_states: torch.Tensor):
+                hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+                router_logits = self.router(hidden_states)
+                if isinstance(router_logits, tuple):
+                    router_scores, router_logits = router_logits
+                    router_scores = router_scores.t()
+                else:
+                    # transformers < 4.54.0 only returns router_logits
+                    router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+
+                    router_scores = (
+                        torch.full_like(router_logits, float("-inf"))
+                        .scatter_(1, router_indices, router_top_value)
+                        .transpose(0, 1)
+                    )
+                    router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+
+                out = self.shared_expert(hidden_states)
+                for i in range(self.num_experts):
+                    out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+
+                return out, router_logits
+
+        model = model.to("cpu")
+
+        def process_module(name, module, model, config):
+            if isinstance(module, Llama4TextMoe):
+                new_module = SequentialLlama4TextMoe(config=config, original=module)
+                parent, child = name.rsplit(".", maxsplit=1)
+                parent = model.get_submodule(parent)
+                setattr(parent, child, new_module)
+
+        with ThreadPoolExecutor(max_workers=8) as executor:
+            process_fn = partial(process_module, model=model, config=model.config.get_text_config())
+            list(executor.map(lambda x: process_fn(x[0], x[1]), model.named_modules()))
+
+        return model
diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py
index c68261162..edd63149d 100644
--- a/gptqmodel/models/loader.py
+++ b/gptqmodel/models/loader.py
@@ -478,7 +478,10 @@ def skip(*args, **kwargs):
         model.checkpoint_file_name = model_save_name
 
         if cls.dynamic_expert_index is not None:
-            num_experts = getattr(config, cls.dynamic_expert_index)
+            if hasattr(config, "text_config"):
+                num_experts = getattr(config.text_config, cls.dynamic_expert_index)
+            else:
+                num_experts = getattr(config, cls.dynamic_expert_index)
             cls.layer_modules = get_moe_layer_modules(layer_modules=cls.layer_modules,
                                                       num_experts=num_experts)
 
diff --git a/tests/models/test_llama4.py b/tests/models/test_llama4.py
new file mode 100644
index 000000000..43a9f4d30
--- /dev/null
+++ b/tests/models/test_llama4.py
@@ -0,0 +1,29 @@
+# Copyright 2024-2025 ModelCloud.ai
+# Copyright 2024-2025 qubitium@modelcloud.ai
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from model_test import ModelTest
+
+
+class TestLlama4(ModelTest):
+    NATIVE_MODEL_ID = "/monster/data/model/Llama-4-Scout-17B-16E-Instruct"  # "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+    NATIVE_ARC_CHALLENGE_ACC = 0.3567
+    NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805
+    QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36
+    APPLY_CHAT_TEMPLATE = True
+    TRUST_REMOTE_CODE = False
+
+    def test_llama4(self):
+        self.quant_lm_eval()
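
A minimal usage sketch for the new Llama 4 path, assuming the existing public GPTQModel.load / quantize / save API and default 4-bit, group-size-128 GPTQ settings; the model id, output directory, and calibration texts are illustrative placeholders rather than values fixed by this change. Because Llama4GPTQ sets support_batch_quantize = False, quantize() pins the effective batch size to 1 for this model.

# Sketch only: quantize Llama 4 Scout through the new Llama4GPTQ definition.
# Assumes the public GPTQModel API; paths and calibration data are placeholders.
from gptqmodel import GPTQModel, QuantizeConfig

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
quant_path = "Llama-4-Scout-17B-16E-Instruct-gptq-4bit"

calibration = [
    "Llama 4 routes each token to a subset of local experts plus a shared expert.",
    "GPTQ calibrates every quantized linear layer against sample activations.",
]

quant_config = QuantizeConfig(bits=4, group_size=128)

model = GPTQModel.load(model_id, quant_config)
# support_batch_quantize = False on Llama4GPTQ forces batch_size to 1 internally.
model.quantize(calibration, batch_size=1)
model.save(quant_path)

# The saved checkpoint can later be reloaded for inference with GPTQModel.load(quant_path).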