Describe the bug
When trying to quantize Gemma 3 4B with the following script:
# Copyright 2024-2025 ModelCloud.ai
# Copyright 2024-2025 qubitium@modelcloud.ai
# Contact: qubitium@modelcloud.ai, x.com/qubitium
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig
from transformers import AutoTokenizer

pretrained_model_id = "unsloth/gemma-3-4b-it"  # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quantized_model_id = "gemma-3-4b-gptq"
# os.makedirs(quantized_model_dir, exist_ok=True)


def get_wikitext2(tokenizer, nsamples, seqlen):
    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train").filter(
        lambda x: len(x["text"]) >= seqlen)
    return [tokenizer(example["text"]) for example in traindata.select(range(nsamples))]


@torch.no_grad()
def calculate_avg_ppl(model, tokenizer):
    from gptqmodel.utils import Perplexity

    ppl = Perplexity(
        model=model,
        tokenizer=tokenizer,
        dataset_path="wikitext",
        dataset_name="wikitext-2-raw-v1",
        split="train",
        text_column="text",
    )

    all = ppl.calculate(n_ctx=512, n_batch=512)

    # average ppl
    avg = sum(all) / len(all)
    return avg


def main():
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id, use_fast=True)
    traindataset = get_wikitext2(tokenizer, nsamples=256, seqlen=1024)

    quantize_config = QuantizeConfig(
        bits=4,        # quantize model to 4-bit
        group_size=128,  # it is recommended to set the value to 128
    )

    # load un-quantized model, the model will always be force loaded into cpu
    model = GPTQModel.load(pretrained_model_id, quantize_config)

    # quantize model, the calibration_dataset should be list of dict whose keys can only be "input_ids" and "attention_mask"
    # with value under torch.LongTensor type.
    model.quantize(traindataset)

    # save quantized model using safetensors
    model.save(quantized_model_id)


if __name__ == "__main__":
    import logging

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    main()
running it produces a config.json like this:
{
"architectures": [
"Gemma3ForConditionalGeneration"
],
"boi_token_index": 255999,
"bos_token_id": 2,
"eoi_token_index": 256000,
"eos_token_id": 106,
"image_token_index": 262144,
"initializer_range": 0.02,
"mm_tokens_per_image": 256,
"model_type": "gemma3",
"pad_token_id": 0,
"quantization_config": {
"bits": 4,
"checkpoint_format": "gptq",
"desc_act": true,
"group_size": 128,
"lm_head": false,
"meta": {
"damp_auto_increment": 0.01,
"damp_percent": 0.05,
"mse": 0.0,
"quantizer": [
"gptqmodel:4.0.0-dev"
],
"static_groups": false,
"true_sequential": true,
"uri": "https://github.com/modelcloud/gptqmodel",
"v2": false,
"v2_alpha": 0.25
},
"pack_dtype": "int32",
"quant_method": "gptq",
"sym": true
},
"text_config": {
"attention_bias": false,
"attention_dropout": 0.0,
"attn_logit_softcapping": null,
"cache_implementation": "hybrid",
"final_logit_softcapping": null,
"head_dim": 256,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2560,
"initializer_range": 0.02,
"intermediate_size": 10240,
"layer_types": [
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"full_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention",
"sliding_attention"
],
"max_position_embeddings": 131072,
"model_type": "gemma3_text",
"num_attention_heads": 8,
"num_hidden_layers": 34,
"num_key_value_heads": 4,
"query_pre_attn_scalar": 256,
"rms_norm_eps": 1e-06,
"rope_local_base_freq": 10000.0,
"rope_scaling": {
"factor": 8.0,
"rope_type": "linear"
},
"rope_theta": 1000000.0,
"sliding_window": 1024,
"sliding_window_pattern": 6,
"torch_dtype": "bfloat16",
"use_cache": true,
"vocab_size": 262208
},
"torch_dtype": "bfloat16",
"transformers_version": "4.53.1",
"unsloth_fixed": true,
"use_cache": false,
"vision_config": {
"attention_dropout": 0.0,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 896,
"intermediate_size": 4304,
"layer_norm_eps": 1e-06,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 27,
"patch_size": 14,
"torch_dtype": "bfloat16",
"vision_use_head": false
}
}
However, when loading the model in vLLM I see the following, which indicates that the saved checkpoint itself does not contain the vision weights even though the config says it should:
INFO 07-09 17:49:10 [weight_utils.py:345] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
ERROR 07-09 17:49:19 [core.py:586] EngineCore failed to start.
ERROR 07-09 17:49:19 [core.py:586] Traceback (most recent call last):
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 577, in run_engine_core
ERROR 07-09 17:49:19 [core.py:586] engine_core = EngineCoreProc(*args, **kwargs)
ERROR 07-09 17:49:19 [core.py:586] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 404, in __init__
Process EngineCore_0:
ERROR 07-09 17:49:19 [core.py:586] super().__init__(vllm_config, executor_class, log_stats,
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 75, in __init__
ERROR 07-09 17:49:19 [core.py:586] self.model_executor = executor_class(vllm_config)
ERROR 07-09 17:49:19 [core.py:586] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 53, in __init__
ERROR 07-09 17:49:19 [core.py:586] self._init_executor()
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 48, in _init_executor
ERROR 07-09 17:49:19 [core.py:586] self.collective_rpc("load_model")
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 07-09 17:49:19 [core.py:586] answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 07-09 17:49:19 [core.py:586] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/utils/__init__.py", line 2736, in run_method
ERROR 07-09 17:49:19 [core.py:586] return func(*args, **kwargs)
ERROR 07-09 17:49:19 [core.py:586] ^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 185, in load_model
ERROR 07-09 17:49:19 [core.py:586] self.model_runner.load_model()
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1776, in load_model
ERROR 07-09 17:49:19 [core.py:586] self.model = model_loader.load_model(
ERROR 07-09 17:49:19 [core.py:586] ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/base_loader.py", line 41, in load_model
ERROR 07-09 17:49:19 [core.py:586] self.load_weights(model, model_config)
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/default_loader.py", line 269, in load_weights
ERROR 07-09 17:49:19 [core.py:586] loaded_weights = model.load_weights(
ERROR 07-09 17:49:19 [core.py:586] ^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/gemma3_mm.py", line 720, in load_weights
ERROR 07-09 17:49:19 [core.py:586] return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
ERROR 07-09 17:49:19 [core.py:586] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 291, in load_weights
ERROR 07-09 17:49:19 [core.py:586] autoloaded_weights = set(self._load_module("", self.module, weights))
ERROR 07-09 17:49:19 [core.py:586] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 249, in _load_module
ERROR 07-09 17:49:19 [core.py:586] yield from self._load_module(prefix,
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 222, in _load_module
ERROR 07-09 17:49:19 [core.py:586] loaded_params = module_load_weights(weights)
ERROR 07-09 17:49:19 [core.py:586] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/siglip.py", line 514, in load_weights
ERROR 07-09 17:49:19 [core.py:586] param = params_dict[name]
ERROR 07-09 17:49:19 [core.py:586] ~~~~~~~~~~~^^^^^^
ERROR 07-09 17:49:19 [core.py:586] KeyError: 'vision_model.encoder.layers.0.self_attn.qkv_proj.weight'
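For reference, a minimal load along these lines is enough to hit the error above (the model path points at the local quantized output; any other engine arguments are defaults, so this is only a sketch, not my exact invocation):

from vllm import LLM

# load the locally saved GPTQ output; everything else left at vLLM defaults
llm = LLM(model="gemma-3-4b-gptq")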
The output is also missing several files related to the multimodal and tokenizer features, mainly preprocessor_config.json, processor_config.json, model.safetensors.index.json, tokenizer.model, added_tokens.json, and chat_template.json. I also noticed that the saved special_tokens_map.json differs from the base model's, which seems odd.
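For anyone wanting to double-check, a quick sketch that diffs the quantized output directory against the base repo's file list (repo id and local path as used above):

import os
from huggingface_hub import list_repo_files

base_files = set(list_repo_files("unsloth/gemma-3-4b-it"))
quant_files = set(os.listdir("gemma-3-4b-gptq"))
# files present in the base repo but absent from the quantized output
print(sorted(base_files - quant_files))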
GPU Info
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.03 Driver Version: 550.144.03 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A100 80GB PCIe Off | 00000001:00:00.0 Off | 0 |
| N/A 38C P0 46W / 300W | 6MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
Software Info
Distributor ID: Ubuntu
Description: Ubuntu 24.04.2 LTS
Release: 24.04
Codename: noble
Name: accelerate
Version: 1.8.1
Location: /home/ubuntu/grouping-fine-tune/quantize/.venv/lib/python3.12/site-packages
Requires: huggingface-hub, numpy, packaging, psutil, pyyaml, safetensors, torch
Required-by: gptqmodel
---
Name: gptqmodel
Version: 4.0.0.dev0
Location: /home/ubuntu/grouping-fine-tune/quantize/.venv/lib/python3.12/site-packages
Requires: accelerate, device-smi, hf-transfer, huggingface-hub, logbar, numpy, packaging, pillow, protobuf, random-word, safetensors, soundfile, threadpoolctl, tokenicer, torch, transformers
Required-by:
---
Name: torch
Version: 2.7.1
Location: /home/ubuntu/grouping-fine-tune/quantize/.venv/lib/python3.12/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-cufile-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-cusparselt-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvtx-cu12, setuptools, sympy, triton, typing-extensions
Required-by: accelerate, gptqmodel
---
Name: transformers
Version: 4.52.4
Location: /home/ubuntu/grouping-fine-tune/quantize/.venv/lib/python3.12/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: gptqmodel, tokenicer
---
Name: triton
Version: 3.3.1
Location: /home/ubuntu/grouping-fine-tune/quantize/.venv/lib/python3.12/site-packages
Requires: setuptools
Required-by: torch
To Reproduce
Run the quantization script above, then load the resulting model with vLLM.
Expected behavior
Either the multimodal (vision) layers should be kept/quantized too, or the saved config should not claim the model is multimodal. Additionally (not sure if this should be a separate issue), the files mentioned earlier, like preprocessor_config.json, should be present and correct.
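As a possible stopgap, the missing sidecar files could be copied over from the base repo by hand (an untested sketch; it would not restore the missing vision weights, so the vLLM error above would remain):

import shutil
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

for fname in ["preprocessor_config.json", "processor_config.json",
              "chat_template.json", "tokenizer.model", "added_tokens.json"]:
    try:
        # pull the file from the base repo and drop it next to the quantized weights
        shutil.copy(hf_hub_download("unsloth/gemma-3-4b-it", fname), "gemma-3-4b-gptq/")
    except EntryNotFoundError:
        pass  # not every file is guaranteed to exist in the base repo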
Model/Datasets