Description
🐞 Describing the bug
- The issue arises when using coremltools.optimize.torch.quantization to apply w8a8 quantization to a model containing Conv3d layers. Quantization itself completes, but converting the quantized model to Core ML fails with a ValueError about mismatched weight and scale ranks (see the stack trace below).
Stack Trace
```
Traceback (most recent call last):
File "/Users/silveryu/Developer/lightsvd/diffuserskit/torch_quant.py", line 72, in <module>
ct.convert(traced_model,
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/_converters_entry.py", line 635, in convert
mlmodel = mil_convert(
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
proto, mil_program = mil_convert_to_proto(
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 288, in mil_convert_to_proto
prog = frontend_converter(model, **kwargs)
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/converter.py", line 108, in __call__
return load(*args, **kwargs)
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 88, in load
return _perform_torch_convert(converter, debug)
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 151, in _perform_torch_convert
prog = converter.convert()
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1383, in convert
self.convert_const()
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1251, in convert_const
self._add_const(name, val)
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1206, in _add_const
compression_op = self._construct_compression_op(val.detach().numpy(), name)
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1189, in _construct_compression_op
result = self._construct_quantization_op(val, compression_info, param_name, result)
File "/opt/homebrew/anaconda3/envs/lightsvd/lib/python3.10/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 953, in _construct_quantization_op
raise ValueError(
ValueError: In conv1.weight, the `weight` should have same rank as `scale`, but got (1, 1, 3, 3, 3) vs (1, 1)
```
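The error indicates the converter requires the quantization scale to have the same rank as the weight it compresses: the Conv3d weight of shape (1, 1, 3, 3, 3) is rank 5, while the scale registered for it, of shape (1, 1), is rank 2. Below is a minimal illustration of the shapes involved, in plain NumPy rather than coremltools internals; the rank-5 scale shape is an assumption about what a matching scale would look like:

```python
import numpy as np

# Conv3d weight: (out_channels, in_channels, kD, kH, kW) -> rank 5
weight = np.zeros((1, 1, 3, 3, 3))

# Scale the converter found for this weight (from the error message) -> rank 2
scale = np.zeros((1, 1))

print(weight.ndim, scale.ndim)  # 5 vs 2: the ranks differ, so the check raises

# A rank-matched scale would presumably have a shape like (1, 1, 1, 1, 1),
# broadcastable against the rank-5 weight (assumption, not coremltools output).
```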
To Reproduce
```python
import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

import coremltools as ct
from coremltools.optimize.torch.quantization import (
    LinearQuantizer,
    LinearQuantizerConfig,
    ModuleLinearQuantizerConfig,
)


def torch_quantization(
    model: torch.nn.Module,
    sample_dataloader: DataLoader,
) -> torch.nn.Module:
    # w8a8: int8 weights, uint8 activations, symmetric scheme
    config = LinearQuantizerConfig(
        global_config=ModuleLinearQuantizerConfig(
            weight_dtype="qint8",
            activation_dtype="quint8",
            quantization_scheme="symmetric",
            milestones=[0, 1000, 1000, 0],
        )
    )
    quantizer = LinearQuantizer(model, config)
    example_inputs = next(iter(sample_dataloader))
    quantizer.prepare(example_inputs=example_inputs, inplace=True)
    quantizer.step()
    # Do a forward pass through the model with calibration data
    for data in tqdm.tqdm(sample_dataloader, desc="calibrating"):
        with torch.no_grad():
            model(data)
    return quantizer.finalize()


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 1, 3)

    def forward(self, x):
        return self.conv1(x)


class SampleDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, item):
        # (C, D, H, W); the DataLoader adds the batch dimension
        return torch.randn(1, 12, 224, 224, device="mps", dtype=torch.float32)


test_model = SimpleNet().to("mps")
quantized_model = torch_quantization(test_model, DataLoader(SampleDataset()))

with torch.no_grad():
    traced_model = torch.jit.trace(
        quantized_model.to("cpu"),
        torch.randn(1, 1, 12, 224, 224, dtype=torch.float32),
    )
    _ = traced_model(torch.randn(1, 1, 12, 224, 224, dtype=torch.float32))

# Fails with the ValueError above
ct.convert(
    traced_model,
    inputs=[ct.TensorType(name="input", shape=(1, 1, 12, 224, 224))],
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.iOS18,
)
```
- This code triggers the issue when performing quantization on a model with a Conv3d layer using w8a8 quantization.
System environment:
- Coremltools: 8.1
- OS: macOS 15.3
- PyTorch: 2.4.0
- Python: 3.10.15
Additional context
- The issue appears specifically when quantizing a model that contains a Conv3d layer. Other layers, such as Conv2d, seem to work fine (see the comparison sketch below).
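For comparison, here is a minimal sketch of the same pipeline with Conv2d in place of Conv3d, reusing the imports and the torch_quantization helper from the reproduction above (run on CPU for simplicity). Per the observation above, this variant converts without the rank error:

```python
# Same pipeline as above, with Conv2d instead of Conv3d.
class SimpleNet2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv1(x)


class SampleDataset2d(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, item):
        # (C, H, W); the DataLoader adds the batch dimension
        return torch.randn(1, 224, 224, dtype=torch.float32)


quantized_2d = torch_quantization(SimpleNet2d(), DataLoader(SampleDataset2d()))

with torch.no_grad():
    traced_2d = torch.jit.trace(quantized_2d, torch.randn(1, 1, 224, 224))

# Unlike the Conv3d variant, this conversion completes without the rank error.
ct.convert(
    traced_2d,
    inputs=[ct.TensorType(name="input", shape=(1, 1, 224, 224))],
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.iOS18,
)
```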