Skip to content

Commit 7b9b90b

Browse files
authored
Support string tensors in DeduplicateInitializersPass and call_onnx_api (#169)
Getting the number of bytes in a string tensor needs special treatment: the STRING data type does not define a bitwidth, so the size must be computed by flattening the strings into a sequence of bytes (see `call_onnx_api`, which queries `.nbytes` to temporarily remove large initializers). In `DeduplicateInitializersPass`, string initializers need to be converted to bytes via the `.string_data` method instead of the `.tobytes` method. Also adds a mapping of `np.bytes_` to `STRING` in `DataType.from_numpy`, which was probably just overlooked before, since `object` and `np.str_` (unicode strings) are already handled. Related to microsoft/onnxscript#2514 --------- Signed-off-by: Christoph Berganski <christoph.berganski@gmail.com>
1 parent 40c0fe4 commit 7b9b90b

File tree

5 files changed

+91
-11
lines changed

5 files changed

+91
-11
lines changed

src/onnx_ir/_core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,11 @@ def shape(self) -> Shape:
836836
"""The shape of the tensor. Immutable."""
837837
return self._shape
838838

839+
@property
def nbytes(self) -> int:
    """The number of bytes in the tensor.

    STRING has no fixed bitwidth, so the size is the sum of the lengths
    of the individual flattened strings returned by ``string_data()``.
    """
    return sum(map(len, self.string_data()))
843+
839844
@property
840845
def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
841846
"""Backing data of the tensor. Immutable."""

src/onnx_ir/_core_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2485,5 +2485,27 @@ def test_integration_with_regular_tensor_operations(self):
24852485
self.assertEqual(result.sum(), 10) # 1+2+3+4 = 10
24862486

24872487

2488+
class StringTensorTest(unittest.TestCase):
    """Tests for the ``StringTensor.nbytes`` property."""

    def _nbytes_of(self, data):
        # Wrap the numpy bytes array in a StringTensor and report its size.
        return _core.StringTensor(data).nbytes

    def test_nbytes(self):
        self.assertEqual(self._nbytes_of(np.array([b"A", b"BC", b"D"])), 4)

    def test_nbytes_2d(self):
        # Byte counting must flatten across all dimensions.
        data = np.array([[b"A", b"BC", b"D"], [b"EFG", b"H", b"I"]])
        self.assertEqual(self._nbytes_of(data), 9)

    def test_nbytes_empty(self):
        # An empty tensor contains zero bytes.
        self.assertEqual(self._nbytes_of(np.array([])), 0)

    def test_nbytes_single(self):
        self.assertEqual(self._nbytes_of(np.array([b"ABC"])), 3)
2508+
2509+
24882510
if __name__ == "__main__":
24892511
unittest.main()

src/onnx_ir/_enums.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def from_numpy(cls, dtype: np.dtype) -> DataType:
7777
if dtype in _NP_TYPE_TO_DATA_TYPE:
7878
return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
7979

80-
if np.issubdtype(dtype, np.str_):
80+
if np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_):
8181
return DataType.STRING
8282

8383
# Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
@@ -215,6 +215,10 @@ def is_signed(self) -> bool:
215215
DataType.FLOAT8E8M0,
216216
}
217217

218+
def is_string(self) -> bool:
    """Return True when this data type is the ONNX STRING type."""
    return DataType.STRING == self
221+
218222
def __repr__(self) -> str:
219223
return self.name
220224

src/onnx_ir/passes/common/initializer_deduplication.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import hashlib
1111
import logging
1212

13+
import numpy as np
14+
1315
import onnx_ir as ir
1416

1517
logger = logging.getLogger(__name__)
@@ -42,17 +44,27 @@ def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
4244
size_limit,
4345
)
4446
return True
45-
46-
if const_val.dtype == ir.DataType.STRING:
47-
# Skip string initializers as they don't have a bytes representation
48-
logger.warning(
49-
"Skipped deduplication of string initializer '%s' (unsupported yet)",
50-
initializer.name,
51-
)
52-
return True
5347
return False
5448

5549

50+
def _tobytes(val):
51+
"""StringTensor does not support tobytes. Use 'string_data' instead.
52+
53+
However, 'string_data' yields a list of bytes which cannot be hashed, i.e.,
54+
cannot be used to index into a dict. To generate keys for identifying
55+
tensors in initializer deduplication the following converts the list of
56+
bytes to an array of fixed-length strings which can be flattened into a
57+
bytes-string. This, together with the tensor shape, is sufficient for
58+
identifying tensors for deduplication, but it differs from the
59+
representation used for serializing tensors (that is string_data) by adding
60+
padding bytes so that each string occupies the same number of consecutive
61+
bytes in the flattened .tobytes representation.
62+
"""
63+
if val.dtype.is_string():
64+
return np.array(val.string_data()).tobytes()
65+
return val.tobytes()
66+
67+
5668
class DeduplicateInitializersPass(ir.passes.InPlacePass):
5769
"""Remove duplicated initializer tensors from the main graph and all subgraphs.
5870
@@ -84,7 +96,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
8496
const_val = initializer.const_value
8597
assert const_val is not None
8698

87-
key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
99+
key = (const_val.dtype, tuple(const_val.shape), _tobytes(const_val))
88100
if key in initializers:
89101
modified = True
90102
initializer_to_keep = initializers[key] # type: ignore[index]
@@ -143,7 +155,7 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
143155
key = (const_val.dtype, tensor_dims, tensor_digest)
144156

145157
if key in initializers:
146-
if initializers[key].const_value.tobytes() != const_val.tobytes():
158+
if _tobytes(initializers[key].const_value) != _tobytes(const_val):
147159
logger.warning(
148160
"Initializer deduplication failed: "
149161
"hashes match but values differ with values %s and %s",

src/onnx_ir/passes/common/initializer_deduplication_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ def test_deduplicates_identical_initializers(self):
4747
add_node = new_model.graph[0]
4848
self.assertEqual(add_node.inputs[0], add_node.inputs[1])
4949

50+
def test_deduplicates_identical_string_initializers(self):
    """String initializers with equal dtype, shape, and values are merged."""
    original = ir.from_onnx_text(
        """
        <ir_version: 10, opset_import: ["" : 17]>
        agraph () => ()
        <string[2] s1 = {"A", "B"}, string[2] s2 = {"A", "B"}> {
        }
        """
    )
    self.assertEqual(len(original.graph.initializers), 2)
    deduplicated = self.apply_pass(original)
    self.assertEqual(len(deduplicated.graph.initializers), 1)
62+
5063
def test_initializers_with_different_shapes_not_deduplicated(self):
5164
model = ir.from_onnx_text(
5265
"""
@@ -60,6 +73,30 @@ def test_initializers_with_different_shapes_not_deduplicated(self):
6073
new_model = self.apply_pass(model)
6174
self.assertEqual(len(new_model.graph.initializers), 2)
6275

76+
def test_string_initializers_with_different_shapes_not_deduplicated(self):
    """Equal string values under different shapes must stay separate."""
    original = ir.from_onnx_text(
        """
        <ir_version: 10, opset_import: ["" : 17]>
        agraph () => ()
        <string[2] s1 = {"A", "B"}, string[1,2] s2 = {"A", "B"}> {
        }
        """
    )
    deduplicated = self.apply_pass(original)
    self.assertEqual(len(deduplicated.graph.initializers), 2)
87+
88+
def test_string_initializers_with_same_bytes_but_different_grouping_not_deduplicated(self):
    """Same concatenated bytes split differently must not be merged."""
    original = ir.from_onnx_text(
        """
        <ir_version: 10, opset_import: ["" : 17]>
        agraph () => ()
        <string[2] s1 = {"AB", "C"}, string[2] s2 = {"A", "BC"}> {
        }
        """
    )
    deduplicated = self.apply_pass(original)
    self.assertEqual(len(deduplicated.graph.initializers), 2)
99+
63100
def test_initializers_with_different_dtypes_not_deduplicated(self):
64101
model = ir.from_onnx_text(
65102
"""

0 commit comments

Comments
 (0)