
[BUG] Gemma 3 4b quantization outputs multi-modal config despite it being text only #1655

@ethan-tonic

Description

Describe the bug

When trying to quantize Gemma 3 4b with the script below:

# Copyright 2024-2025 ModelCloud.ai
# Copyright 2024-2025 qubitium@modelcloud.ai
# Contact: qubitium@modelcloud.ai, x.com/qubitium
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig
from transformers import AutoTokenizer

pretrained_model_id = "unsloth/gemma-3-4b-it" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quantized_model_id = "gemma-3-4b-gptq"


def get_wikitext2(tokenizer, nsamples, seqlen):
    traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train").filter(
        lambda x: len(x["text"]) >= seqlen)

    # each tokenizer(...) call yields a dict with "input_ids" and "attention_mask"
    return [tokenizer(example["text"]) for example in traindata.select(range(nsamples))]


@torch.no_grad()
def calculate_avg_ppl(model, tokenizer):
    from gptqmodel.utils import Perplexity

    ppl = Perplexity(
        model=model,
        tokenizer=tokenizer,
        dataset_path="wikitext",
        dataset_name="wikitext-2-raw-v1",
        split="train",
        text_column="text",
    )

    ppl_values = ppl.calculate(n_ctx=512, n_batch=512)

    # average perplexity across all evaluated chunks
    avg = sum(ppl_values) / len(ppl_values)

    return avg

def main():
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id, use_fast=True)

    traindataset = get_wikitext2(tokenizer, nsamples=256, seqlen=1024)

    quantize_config = QuantizeConfig(
        bits=4,  # quantize model to 4-bit
        group_size=128,  # it is recommended to set the value to 128
    )

    # load the un-quantized model; it is always force-loaded onto CPU first
    model = GPTQModel.load(pretrained_model_id, quantize_config)

    # quantize the model; calibration_dataset should be a list of dicts whose
    # only keys are "input_ids" and "attention_mask", with torch.LongTensor values
    model.quantize(traindataset)

    # save quantized model using safetensors
    model.save(quantized_model_id)

if __name__ == "__main__":
    import logging

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    main()

it outputs a config.json like the following:

{
  "architectures": [
    "Gemma3ForConditionalGeneration"
  ],
  "boi_token_index": 255999,
  "bos_token_id": 2,
  "eoi_token_index": 256000,
  "eos_token_id": 106,
  "image_token_index": 262144,
  "initializer_range": 0.02,
  "mm_tokens_per_image": 256,
  "model_type": "gemma3",
  "pad_token_id": 0,
  "quantization_config": {
    "bits": 4,
    "checkpoint_format": "gptq",
    "desc_act": true,
    "group_size": 128,
    "lm_head": false,
    "meta": {
      "damp_auto_increment": 0.01,
      "damp_percent": 0.05,
      "mse": 0.0,
      "quantizer": [
        "gptqmodel:4.0.0-dev"
      ],
      "static_groups": false,
      "true_sequential": true,
      "uri": "https://github.com/modelcloud/gptqmodel",
      "v2": false,
      "v2_alpha": 0.25
    },
    "pack_dtype": "int32",
    "quant_method": "gptq",
    "sym": true
  },
  "text_config": {
    "attention_bias": false,
    "attention_dropout": 0.0,
    "attn_logit_softcapping": null,
    "cache_implementation": "hybrid",
    "final_logit_softcapping": null,
    "head_dim": 256,
    "hidden_activation": "gelu_pytorch_tanh",
    "hidden_size": 2560,
    "initializer_range": 0.02,
    "intermediate_size": 10240,
    "layer_types": [
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "full_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "full_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "full_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "full_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "full_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention",
      "sliding_attention"
    ],
    "max_position_embeddings": 131072,
    "model_type": "gemma3_text",
    "num_attention_heads": 8,
    "num_hidden_layers": 34,
    "num_key_value_heads": 4,
    "query_pre_attn_scalar": 256,
    "rms_norm_eps": 1e-06,
    "rope_local_base_freq": 10000.0,
    "rope_scaling": {
      "factor": 8.0,
      "rope_type": "linear"
    },
    "rope_theta": 1000000.0,
    "sliding_window": 1024,
    "sliding_window_pattern": 6,
    "torch_dtype": "bfloat16",
    "use_cache": true,
    "vocab_size": 262208
  },
  "torch_dtype": "bfloat16",
  "transformers_version": "4.53.1",
  "unsloth_fixed": true,
  "use_cache": false,
  "vision_config": {
    "attention_dropout": 0.0,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 896,
    "intermediate_size": 4304,
    "layer_norm_eps": 1e-06,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 27,
    "patch_size": 14,
    "torch_dtype": "bfloat16",
    "vision_use_head": false
  }
}

However, when loading the model in vLLM I see the following, which indicates that the checkpoint itself does not include the vision weights despite the config saying it does:
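For reference, a minimal load that triggers the failure. This is a sketch, and the exact invocation is an assumption on my part; any vLLM entrypoint that loads this checkpoint goes through the same weight-loading path.

from vllm import LLM

# The config advertises Gemma3ForConditionalGeneration, so vLLM picks the
# multi-modal model class and then tries to load vision weights that the
# checkpoint does not contain.
llm = LLM(model="gemma-3-4b-gptq")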

INFO 07-09 17:49:10 [weight_utils.py:345] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
ERROR 07-09 17:49:19 [core.py:586] EngineCore failed to start.
ERROR 07-09 17:49:19 [core.py:586] Traceback (most recent call last):
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 577, in run_engine_core
ERROR 07-09 17:49:19 [core.py:586]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 07-09 17:49:19 [core.py:586]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 404, in __init__
Process EngineCore_0:
ERROR 07-09 17:49:19 [core.py:586]     super().__init__(vllm_config, executor_class, log_stats,
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 75, in __init__
ERROR 07-09 17:49:19 [core.py:586]     self.model_executor = executor_class(vllm_config)
ERROR 07-09 17:49:19 [core.py:586]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 53, in __init__
ERROR 07-09 17:49:19 [core.py:586]     self._init_executor()
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 48, in _init_executor
ERROR 07-09 17:49:19 [core.py:586]     self.collective_rpc("load_model")
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/uniproc_executor.py", line 57, in collective_rpc
ERROR 07-09 17:49:19 [core.py:586]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 07-09 17:49:19 [core.py:586]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils/__init__.py", line 2736, in run_method
ERROR 07-09 17:49:19 [core.py:586]     return func(*args, **kwargs)
ERROR 07-09 17:49:19 [core.py:586]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 185, in load_model
ERROR 07-09 17:49:19 [core.py:586]     self.model_runner.load_model()
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1776, in load_model
ERROR 07-09 17:49:19 [core.py:586]     self.model = model_loader.load_model(
ERROR 07-09 17:49:19 [core.py:586]                  ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/base_loader.py", line 41, in load_model
ERROR 07-09 17:49:19 [core.py:586]     self.load_weights(model, model_config)
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/default_loader.py", line 269, in load_weights
ERROR 07-09 17:49:19 [core.py:586]     loaded_weights = model.load_weights(
ERROR 07-09 17:49:19 [core.py:586]                      ^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/gemma3_mm.py", line 720, in load_weights
ERROR 07-09 17:49:19 [core.py:586]     return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
ERROR 07-09 17:49:19 [core.py:586]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 291, in load_weights
ERROR 07-09 17:49:19 [core.py:586]     autoloaded_weights = set(self._load_module("", self.module, weights))
ERROR 07-09 17:49:19 [core.py:586]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 249, in _load_module
ERROR 07-09 17:49:19 [core.py:586]     yield from self._load_module(prefix,
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/utils.py", line 222, in _load_module
ERROR 07-09 17:49:19 [core.py:586]     loaded_params = module_load_weights(weights)
ERROR 07-09 17:49:19 [core.py:586]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 07-09 17:49:19 [core.py:586]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/siglip.py", line 514, in load_weights
ERROR 07-09 17:49:19 [core.py:586]     param = params_dict[name]
ERROR 07-09 17:49:19 [core.py:586]             ~~~~~~~~~~~^^^^^^
ERROR 07-09 17:49:19 [core.py:586] KeyError: 'vision_model.encoder.layers.0.self_attn.qkv_proj.weight'

The saved output is also missing several files related to the multi-modal and tokenizer features, mainly preprocessor_config.json, processor_config.json, model.safetensors.index.json, tokenizer.model, added_tokens.json, and chat_template.json. I also noticed that the saved special_tokens_map.json differs from the base model's, which seems odd.
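A quick way to confirm the weight/config mismatch is to list the tensor names in the quantized checkpoint. This is a diagnostic sketch; the shard name model.safetensors is an assumption, since no model.safetensors.index.json was written.

from safetensors import safe_open

# list tensor names in the quantized checkpoint and count any vision tensors
with safe_open("gemma-3-4b-gptq/model.safetensors", framework="pt") as f:
    keys = list(f.keys())
vision_keys = [k for k in keys if "vision" in k]
print(f"{len(keys)} tensors total, {len(vision_keys)} vision tensors")
# a vision count of 0 means the vision tower was never saved,
# matching the vLLM KeyError above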

GPU Info

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.03             Driver Version: 550.144.03     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100 80GB PCIe          Off |   00000001:00:00.0 Off |                    0 |
| N/A   38C    P0             46W /  300W |       6MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Software Info

Distributor ID: Ubuntu
Description:    Ubuntu 24.04.2 LTS
Release:        24.04
Codename:       noble
Name: accelerate
Version: 1.8.1
Location: /home/ubuntu/grouping-fine-tune/quantize/.venv/lib/python3.12/site-packages
Requires: huggingface-hub, numpy, packaging, psutil, pyyaml, safetensors, torch
Required-by: gptqmodel
---
Name: gptqmodel
Version: 4.0.0.dev0
Location: /home/ubuntu/grouping-fine-tune/quantize/.venv/lib/python3.12/site-packages
Requires: accelerate, device-smi, hf-transfer, huggingface-hub, logbar, numpy, packaging, pillow, protobuf, random-word, safetensors, soundfile, threadpoolctl, tokenicer, torch, transformers
Required-by:
---
Name: torch
Version: 2.7.1
Location: /home/ubuntu/grouping-fine-tune/quantize/.venv/lib/python3.12/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-cufile-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-cusparselt-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvtx-cu12, setuptools, sympy, triton, typing-extensions
Required-by: accelerate, gptqmodel
---
Name: transformers
Version: 4.52.4
Location: /home/ubuntu/grouping-fine-tune/quantize/.venv/lib/python3.12/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: gptqmodel, tokenicer
---
Name: triton
Version: 3.3.1
Location: /home/ubuntu/grouping-fine-tune/quantize/.venv/lib/python3.12/site-packages
Requires: setuptools
Required-by: torch


To Reproduce

Run the code above

Expected behavior

Either the multi-modal layers should be quantized too, or the saved config should not claim the model is multi-modal. Additionally (not sure if this should be a separate issue), the files mentioned earlier, such as preprocessor_config.json, should be present and/or correct. A possible stopgap is sketched below.
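One possible workaround, a sketch under the assumption that the saved weights are exactly the text sub-model (which the vLLM KeyError suggests), is to flatten the saved config.json down to its text_config so the checkpoint loads as a text-only Gemma3ForCausalLM:

import json

with open("gemma-3-4b-gptq/config.json") as f:
    cfg = json.load(f)

# promote the nested text_config to the top level and keep the quantization
# settings; the exact field handling here is an assumption based on the
# config dump above
text_cfg = cfg["text_config"]
text_cfg["architectures"] = ["Gemma3ForCausalLM"]
text_cfg["quantization_config"] = cfg["quantization_config"]
for key in ("bos_token_id", "eos_token_id", "pad_token_id"):
    text_cfg[key] = cfg[key]

with open("gemma-3-4b-gptq/config.json", "w") as f:
    json.dump(text_cfg, f, indent=2)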

Model/Datasets

https://huggingface.co/unsloth/gemma-3-4b-it
