-
Notifications
You must be signed in to change notification settings - Fork 724
Open
Labels
bug (type): Unexpected behaviour that should be corrected; triaged (status): Reviewed and examined, release has been assigned if applicable
Description
When converting a traced torchvision model (after applying `roi_align` from #1509), conversion fails with `AssertionError: type_inference: axis=0, i=1: 256 != is452`.
Stack Trace
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
/tmp/ipykernel_31355/3386583322.py in <module>
5 traced_model = torch.jit.trace(model_to_trace, example_image_pt).eval()
6
----> 7 detector_mlmodel = ct.convert(traced_model, inputs=[ct.ImageType(shape=(1, 3, 224, 224))])
8 detector_mlmodel.save("segmenter.mlmodel")
/opt/conda/lib/python3.7/site-packages/coremltools/converters/_converters_entry.py in convert(model, source, inputs, outputs, classifier_config, minimum_deployment_target, convert_to, compute_precision, skip_model_load, compute_units, package_dir, debug)
454 package_dir=package_dir,
455 debug=debug,
--> 456 specification_version=specification_version,
457 )
458
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in mil_convert(model, convert_from, convert_to, compute_units, **kwargs)
185 See `coremltools.converters.convert`
186 """
--> 187 return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
188
189
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in _mil_convert(model, convert_from, convert_to, registry, modelClass, compute_units, **kwargs)
214 convert_to,
215 registry,
--> 216 **kwargs
217 )
218
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in mil_convert_to_proto(model, convert_from, convert_to, converter_registry, **kwargs)
279 frontend_converter = frontend_converter_type()
280
--> 281 prog = frontend_converter(model, **kwargs)
282
283 if convert_to.lower() != "neuralnetwork":
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/converter.py in __call__(self, *args, **kwargs)
107 from .frontend.torch import load
108
--> 109 return load(*args, **kwargs)
110
111
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/load.py in load(model_spec, inputs, specification_version, debug, outputs, cut_at_symbols, **kwargs)
55 inputs = _convert_to_torch_inputtype(inputs)
56 converter = TorchConverter(torchscript, inputs, outputs, cut_at_symbols, specification_version)
---> 57 return _perform_torch_convert(converter, debug)
58
59
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/load.py in _perform_torch_convert(converter, debug)
94 def _perform_torch_convert(converter, debug):
95 try:
---> 96 prog = converter.convert()
97 except RuntimeError as e:
98 if debug and "convert function" in str(e):
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/converter.py in convert(self)
279
280 # Add the rest of the operations
--> 281 convert_nodes(self.context, self.graph)
282
283 graph_outputs = [self.context[name] for name in self.graph.outputs]
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in convert_nodes(context, graph)
87
88 context.prepare_for_conversion(node)
---> 89 add_op(context, node)
90
91 # We've generated all the outputs the graph needs, terminate conversion.
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in scatter(context, node)
5228 mode = 'update'
5229
-> 5230 _scatter(context, inputs, mode, node.name)
5231
5232
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/frontend/torch/ops.py in _scatter(context, inputs, mode, name)
5209 if types.is_scalar(updates.sym_type):
5210 updates = mb.fill(shape=indices.shape, value=updates.val, name=name)
-> 5211 result = mb.scatter_along_axis(data=data, indices=indices, updates=updates,axis=axis, mode=mode, name=name)
5212 context.add(result)
5213
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/ops/registry.py in add_op(cls, **kwargs)
174 op_cls_to_add = op_reg[op_type]
175
--> 176 return cls._add_op(op_cls_to_add, **kwargs)
177
178 setattr(Builder, op_type, add_op)
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/builder.py in _add_op(cls, op_cls, **kwargs)
180 curr_block()._insert_op_before(new_op, before_op=before_op)
181 new_op.build_nested_blocks()
--> 182 new_op.type_value_inference()
183 if len(new_op.outputs) == 1:
184 return new_op.outputs[0]
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/operation.py in type_value_inference(self, overwrite_output)
251 existing _output_vars
252 """
--> 253 output_types = self.type_inference()
254 if not isinstance(output_types, tuple):
255 output_types = (output_types,)
/opt/conda/lib/python3.7/site-packages/coremltools/converters/mil/mil/ops/defs/iOS15/scatter_gather.py in type_inference(self)
431 for i in range(self.data.rank):
432 if i != axis:
--> 433 assert self.data.shape[i] == self.indices.shape[i], f'type_inference: axis={axis}, i={i}: {self.data.shape[i]} != {self.indices.shape[i]}'
434
435 return self.data.sym_type
AssertionError: type_inference: axis=0, i=1: 256 != is452
Steps To Reproduce
import coremltools as ct
import torch, torchvision
from torchvision.transforms import functional as F, InterpolationMode, transforms as T
import requests
from PIL import Image
import numpy as np
from typing import Dict, Tuple, Optional
# Image conversion tools:
class PILToTensor(torch.nn.Module):
    """Transform step that converts ``image`` with ``F.pil_to_tensor``.

    Follows an ``(image, target) -> (image, target)`` transform interface:
    the optional ``target`` dict is returned unchanged alongside the
    converted image.
    """

    def forward(
        self, image: torch.Tensor, target: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        """Return ``(F.pil_to_tensor(image), target)``."""
        # NOTE(review): ``image`` is annotated as a Tensor, but given the
        # ``pil_to_tensor`` call it is presumably a PIL image at runtime —
        # confirm before relying on the annotation.
        return F.pil_to_tensor(image), target
class ConvertImageDtype(torch.nn.Module):
    """Transform step that casts ``image`` with ``F.convert_image_dtype``.

    Args:
        dtype: target dtype, stored at construction and passed to
            ``F.convert_image_dtype`` on every call.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        self.dtype = dtype

    def forward(
        self, image: torch.Tensor, target: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]:
        """Return ``(converted_image, target)``; ``target`` passes through as-is."""
        converted = F.convert_image_dtype(image, self.dtype)
        return converted, target
# Load the torchvision detector with pretrained weights and put it in eval mode.
detector_model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
detector_model = detector_model.eval()

# Get a sample image: download it, decode it as RGB, convert it to a float
# tensor, and add a leading batch dimension for tracing.
toTensor = T.PILToTensor()
toFloatTensor = T.ConvertImageDtype(torch.float)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# Use the response as a context manager so the streamed connection is
# released, fail loudly on HTTP errors instead of handing an error page to
# PIL, and bound the wait with a timeout. ``convert("RGB")`` runs inside the
# ``with`` so the image data is read while the stream is still open.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    example_image = Image.open(response.raw).convert("RGB")
example_image_np = np.array(example_image)
example_image_pt = toFloatTensor(toTensor(example_image))
example_image_pt = example_image_pt.unsqueeze(0)

# Run the sample through the model to demonstrate the model works. This is a
# pure sanity check, so skip building an autograd graph.
with torch.inference_mode():
    y = detector_model(example_image_pt)
# Make an adaptor to convert the model outputs to a tuple
class FasterRCNN_MobileNetV3_AdapterModel(torch.nn.Module):
    """Wrap a detector so its forward returns a flat ``(boxes, labels, scores)`` tuple.

    The wrapped model's output is indexed as ``output[0]['boxes']`` etc., so it
    is presumably a sequence whose first element is a mapping with those keys;
    only that first element is unpacked.

    Args:
        model: the detector to wrap; called once per forward pass.
        w: accepted but ignored — nothing in this adapter reads it.
    """

    def __init__(self, model, w=2):
        super().__init__()
        # ``w`` is neither stored nor used.
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and unbox the first result's tensors."""
        first = self.model(x)[0]
        boxes = first['boxes']
        labels = first['labels']
        scores = first['scores']
        return boxes, labels, scores
adapted_detector_model = FasterRCNN_MobileNetV3_AdapterModel(detector_model)

# Trace and convert the model using coremltools.
model_to_trace = adapted_detector_model
with torch.inference_mode():
    out = model_to_trace(example_image_pt)
traced_model = torch.jit.trace(model_to_trace, example_image_pt).eval()

# The model was traced on a float image produced by
# ``ConvertImageDtype(torch.float)``, which rescales 8-bit pixels into [0, 1].
# A Core ML ImageType feeds raw 0-255 pixel values unless ``scale`` is set, so
# scale by 1/255 to give the converted model the same input range.
# NOTE: with coremltools 6.2 this call is reported to fail with
# "AssertionError: type_inference: axis=0, i=1: 256 != is452", raised from
# scatter_along_axis type inference while converting a torch scatter op.
detector_mlmodel = ct.convert(
    traced_model,
    inputs=[ct.ImageType(shape=example_image_pt.shape, scale=1 / 255.0)],
)
detector_mlmodel.save("segmenter.mlmodel")
System environment:
- coremltools version: 6.2
- OS: Linux (Linux foohostname 4.19.0-23-cloud-amd64 #1 SMP Debian 4.19.269-1 (2022-12-20) x86_64 GNU/Linux)
- Any other relevant version information (e.g. PyTorch or TensorFlow version):
  - Python: 3.7
  - PyTorch: 1.11.1+cu102
- Other libraries installed as dependencies of coremltools:
Requirement already satisfied: coremltools==6.2 in /opt/conda/lib/python3.7/site-packages (6.2)
Requirement already satisfied: tqdm in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (4.64.1)
Requirement already satisfied: protobuf<=4.0.0,>=3.1.0 in /home/jupyter/.local/lib/python3.7/site-packages (from coremltools==6.2) (3.20.1)
Requirement already satisfied: packaging in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (21.3)
Requirement already satisfied: numpy>=1.14.5 in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (1.21.6)
Requirement already satisfied: sympy in /opt/conda/lib/python3.7/site-packages (from coremltools==6.2) (1.10.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging->coremltools==6.2) (3.0.9)
Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.7/site-packages (from sympy->coremltools==6.2) (1.2.1)
Please advise. Thank you!
Metadata
Metadata
Assignees
Labels
bug (type): Unexpected behaviour that should be corrected; triaged (status): Reviewed and examined, release has been assigned if applicable