Skip to content

Commit 12047d2

Browse files
authored
Support numpy scalars in Tensor (#102)
Numpy scalars (https://numpy.org/doc/2.2/reference/arrays.scalars.html) have `__array__` defined but doesn't behave like normal np arrays sometimes. This change updates the initialization logic in `Tensor` to detect them and turn them into np.ndarray. Fixes #101. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 92fc9d7 commit 12047d2

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

src/onnx_ir/_core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,9 @@ def __init__(
417417
else:
418418
self._shape = shape
419419
self._shape.freeze()
420+
if isinstance(value, np.generic):
421+
# Turn numpy scalar into a numpy array
422+
value = np.array(value) # type: ignore[assignment]
420423
if dtype is None:
421424
if isinstance(value, np.ndarray):
422425
self._dtype = _enums.DataType.from_numpy(value.dtype)

src/onnx_ir/_core_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,17 @@ def test_initialize_with_just_np_array(self):
6969
tensor = _core.Tensor(array)
7070
np.testing.assert_array_equal(tensor, array)
7171

72+
@parameterized.parameterized.expand(
73+
[
74+
("bfloat16", ml_dtypes.bfloat16(0.5)),
75+
("float32", np.float32(0.5)),
76+
("bool", np.bool(True)),
77+
]
78+
)
79+
def test_initialize_with_np_number(self, _: str, number: np.generic):
80+
tensor = _core.Tensor(number)
81+
np.testing.assert_equal(tensor.numpy(), np.array(number), strict=True)
82+
7283
def test_initialize_raises_when_numpy_dtype_doesnt_match(self):
7384
array = np.random.rand(1, 2).astype(np.float32)
7485
with self.assertRaises(TypeError):

0 commit comments

Comments
 (0)