Skip to content

Commit 745f68a

Browse files
authored
Require ml_dtypes>=0.5.0 (#188)
Require ml_dtypes>=0.5.0 as a dependency because an older version of it does not have all the dtypes we need. --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent aa0039e commit 745f68a

File tree

2 files changed

+2
-8
lines changed

2 files changed

+2
-8
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ classifiers = [
2121
"Programming Language :: Python :: 3.12",
2222
"Programming Language :: Python :: 3.13",
2323
]
24-
dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"]
24+
dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes>=0.5.0"]
2525

2626
[project.urls]
2727
Homepage = "https://onnx.ai/ir-py"

src/onnx_ir/_enums.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -422,15 +422,9 @@ def __str__(self) -> str:
422422
np.dtype(ml_dtypes.float8_e8m0fnu): DataType.FLOAT8E8M0,
423423
np.dtype(ml_dtypes.int4): DataType.INT4,
424424
np.dtype(ml_dtypes.uint4): DataType.UINT4,
425+
np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1,
425426
}
426427

427-
# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
428-
_NP_TYPE_TO_DATA_TYPE.update(
429-
{np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1}
430-
if hasattr(ml_dtypes, "float4_e2m1fn")
431-
else {}
432-
)
433-
434428
# ONNX DataType to Numpy dtype.
435429
_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
436430

0 commit comments

Comments
 (0)