Skip to content

Commit 467f583

Browse files
authored
Create is_unknown and has_unknown_dim on Shape (#201)
This PR enables usages described in #199. This pull request enhances the handling and documentation of unknown dimensions in the `Shape` class within the ONNX IR core. It introduces new methods to check for unknown dimensions, updates the documentation to clarify shape equality with unknowns, and adds comprehensive tests to ensure correct behavior. Additionally, it clarifies the serialization logic for symbolic dimensions with unknown values. **Shape API improvements:** * Added `is_unknown(dim)` and `has_unknown_dim()` methods to the `Shape` class to check for unknown (i.e., `None`) dimensions, and updated the class docstring to clarify shape equality semantics when unknowns are present. **Serialization:** * Clarified that a symbolic dimension with `None` value is valid and should result in an empty field in the serialized proto, reflecting ONNX semantics for unknown dimensions. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent e6a9b8b commit 467f583

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

src/onnx_ir/_core.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,12 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
12191219
12201220
Use :meth:`get_denotation` and :meth:`set_denotation` to access and modify the denotations.
12211221
1222+
.. note::
1223+
Two shapes can be compared for equality. Be careful when comparing shapes with
1224+
unknown dimensions (``None``), as they may not be considered semantically equal
1225+
even if all dimensions are the same. You can use :meth:`has_unknown_dim` to
1226+
check if a shape has any unknown dimensions.
1227+
12221228
Example::
12231229
12241230
>>> import onnx_ir as ir
@@ -1427,6 +1433,29 @@ def is_dynamic(self, dim=None) -> bool:
14271433
return not self.is_static()
14281434
return not self.is_static(dim)
14291435

1436+
def is_unknown_dim(self, dim: int) -> bool:
1437+
"""Return True if the dimension is unknown (None).
1438+
1439+
A dynamic dimension without a symbolic name is considered unknown.
1440+
1441+
.. versionadded:: 0.1.10
1442+
1443+
Args:
1444+
dim: The index of the dimension.
1445+
"""
1446+
dim_obj = self._dims[dim]
1447+
return isinstance(dim_obj, SymbolicDim) and dim_obj.value is None
1448+
1449+
def has_unknown_dim(self) -> bool:
1450+
"""Return True if any dimension is unknown (None).
1451+
1452+
You can use :meth:`is_unknown_dim` to check if a specific dimension is unknown.
1453+
1454+
.. versionadded:: 0.1.10
1455+
"""
1456+
# We can use "in" directly because SymbolicDim implements __eq__ with None
1457+
return None in self._dims
1458+
14301459

14311460
def _quoted(string: str) -> str:
14321461
"""Return a quoted string.

src/onnx_ir/_core_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,45 @@ def test_is_dynamic_on_empty_shape(self):
743743
shape = _core.Shape(())
744744
self.assertFalse(shape.is_dynamic())
745745

746+
def test_is_unknown_dim(self):
747+
shape = _core.Shape([42, None, "any string", None])
748+
self.assertFalse(shape.is_unknown_dim(0)) # integer dimension is not unknown
749+
self.assertTrue(shape.is_unknown_dim(1)) # None dimension is unknown
750+
self.assertFalse(
751+
shape.is_unknown_dim(2)
752+
) # string dimension is not unknown (it's symbolic)
753+
self.assertTrue(shape.is_unknown_dim(3)) # None dimension is unknown
754+
755+
def test_is_unknown_dim_raises_when_index_out_of_range(self):
756+
shape = _core.Shape([42])
757+
with self.assertRaises(IndexError):
758+
shape.is_unknown_dim(1)
759+
760+
def test_has_unknown_dim(self):
761+
# Shape with unknown dimensions
762+
shape = _core.Shape([42, None, "any string"])
763+
self.assertTrue(shape.has_unknown_dim())
764+
765+
# Shape with only None dimensions
766+
shape = _core.Shape([None, None])
767+
self.assertTrue(shape.has_unknown_dim())
768+
769+
# Shape with no unknown dimensions (static and symbolic)
770+
shape = _core.Shape([42, "any string", 64])
771+
self.assertFalse(shape.has_unknown_dim())
772+
773+
# Shape with only static dimensions
774+
shape = _core.Shape([42, 64, 128])
775+
self.assertFalse(shape.has_unknown_dim())
776+
777+
# Shape with only symbolic dimensions
778+
shape = _core.Shape(["batch", "height", "width"])
779+
self.assertFalse(shape.has_unknown_dim())
780+
781+
def test_has_unknown_dim_on_empty_shape(self):
782+
shape = _core.Shape(())
783+
self.assertFalse(shape.has_unknown_dim())
784+
746785

747786
class ValueTest(unittest.TestCase):
748787
def setUp(self) -> None:

src/onnx_ir/serde.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2000,5 +2000,8 @@ def serialize_dimension_into(
20002000
dim_proto.dim_value = dim
20012001
elif isinstance(dim, (_core.SymbolicDim, _protocols.SymbolicDimProtocol)):
20022002
if dim.value is not None:
2003-
# TODO(justinchuby): None is probably not a valid value for dim_param
20042003
dim_proto.dim_param = str(dim.value)
2004+
# NOTE: None is a valid value for symbolic dimension:
2005+
# A dimension MAY have neither dim_value nor dim_param set. Such a dimension
2006+
# represents an unknown dimension unrelated to other unknown dimensions.
2007+
# Here we will just leave the dim_proto empty.

0 commit comments

Comments
 (0)