Skip to content

Commit ec6f7be

Browse files
authored
Handle when the value type is not known when serializing shapes (#203)
When the value type is not known, we don't know where to store the shape inside a type proto. This change will make the serializer skip serializing the shape instead of error out. Fix #202 --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 467f583 commit ec6f7be

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

src/onnx_ir/serde.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1975,11 +1975,26 @@ def serialize_type(type_protocol: _protocols.TypeProtocol) -> onnx.TypeProto:
19751975
@_capture_errors(lambda type_proto, from_: repr(from_))
19761976
def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None:
19771977
value_field = type_proto.WhichOneof("value")
1978+
if value_field is None:
1979+
# We cannot write the shape because we do not know where to write it
1980+
logger.warning(
1981+
# TODO(justinchuby): Show more context about the value when move everything to an object
1982+
"The value type for shape %s is not known. Please set type for the value. Skipping serialization",
1983+
from_,
1984+
)
1985+
return
19781986
tensor_type = getattr(type_proto, value_field)
19791987
while not isinstance(tensor_type.elem_type, int):
19801988
# Find the leaf type that has the shape field
19811989
type_proto = tensor_type.elem_type
19821990
value_field = type_proto.WhichOneof("value")
1991+
if value_field is None:
1992+
logger.warning(
1993+
# TODO(justinchuby): Show more context about the value when move everything to an object
1994+
"The value type for shape %s is not known. Please set type for the value. Skipping serialization",
1995+
from_,
1996+
)
1997+
return
19831998
tensor_type = getattr(type_proto, value_field)
19841999
# When from is empty, we still need to set the shape field to an empty list by touching it
19852000
tensor_type.shape.ClearField("dim")

src/onnx_ir/serde_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,15 @@ def test_serialize_attribute(self, _: str, typ: ir.AttributeType, value, expecte
598598
self.assertEqual(deserialized_attr.type, attr.type)
599599
self.assertEqual(deserialized_attr.value, expected)
600600

601+
def test_serialize_shape_into_skips_writing_when_value_type_not_known(self):
602+
shape = ir.Shape((1, 2, 3))
603+
proto = onnx.TypeProto()
604+
self.assertIsNone(proto.WhichOneof("value"))
605+
serde.serialize_shape_into(proto, shape)
606+
self.assertIsNone(proto.WhichOneof("value"))
607+
deserialized = serde.deserialize_type_proto_for_shape(proto)
608+
self.assertIsNone(deserialized, shape)
609+
601610

602611
class QuantizationAnnotationTest(unittest.TestCase):
603612
"""Test that quantization annotations are correctly serialized and deserialized."""

0 commit comments

Comments
 (0)