
Commit ca183e7

Fix support for torch.float8_e8m0fnu for older versions of PyTorch (#145)
`torch.float8_e8m0fnu` was added only in PyTorch 2.7. In the `to_torch_dtype` method, where we build a dictionary mapping ONNX dtypes to torch dtypes, an undefined-member error would be raised if the user has a torch version < 2.7. This change conditionally adds `float8_e8m0fnu` to the mapping if it is defined in torch. When users request the dtype in an environment with an older version of PyTorch, an error is raised suggesting that they upgrade torch.

---------

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent aab601f commit ca183e7
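
The fix probes for the dtype at runtime with `hasattr` instead of pinning a minimum torch version. Below is a minimal, self-contained sketch of that guard pattern; `OnnxDataType` and its members are hypothetical stand-ins for `ir.DataType`, not the actual onnx_ir API:

```python
import enum

import torch


class OnnxDataType(enum.IntEnum):
    # Hypothetical stand-in for ir.DataType (values are illustrative)
    FLOAT8E5M2 = 19
    FLOAT8E8M0 = 24


def to_torch_dtype(dtype: OnnxDataType) -> torch.dtype:
    mapping = {OnnxDataType.FLOAT8E5M2: torch.float8_e5m2}
    if hasattr(torch, "float8_e8m0fnu"):
        # torch.float8_e8m0fnu exists only in PyTorch 2.7+, so the guard
        # keeps this module importable and usable on older versions.
        mapping[OnnxDataType.FLOAT8E8M0] = torch.float8_e8m0fnu
    if dtype not in mapping:
        if dtype == OnnxDataType.FLOAT8E8M0:
            # Targeted error instead of a generic "unsupported dtype" message
            raise ValueError("FLOAT8E8M0 requires PyTorch 2.7+; please upgrade torch.")
        raise TypeError(f"Unsupported dtype: {dtype}")
    return mapping[dtype]
```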

File tree: 2 files changed, +75 -2 lines


src/onnx_ir/tensor_adapters.py

Lines changed: 14 additions & 2 deletions
@@ -68,7 +68,6 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
         torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
         torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
         torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
-        torch.float8_e8m0fnu: ir.DataType.FLOAT8E8M0,
         torch.int16: ir.DataType.INT16,
         torch.int32: ir.DataType.INT32,
         torch.int64: ir.DataType.INT64,
@@ -78,6 +77,10 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
         torch.uint32: ir.DataType.UINT32,
         torch.uint64: ir.DataType.UINT64,
     }
+    if hasattr(torch, "float8_e8m0fnu"):
+        # torch.float8_e8m0fnu is available in PyTorch 2.7+
+        _TORCH_DTYPE_TO_ONNX[torch.float8_e8m0fnu] = ir.DataType.FLOAT8E8M0
+
     if dtype not in _TORCH_DTYPE_TO_ONNX:
         raise TypeError(
             f"Unsupported PyTorch dtype '{dtype}'. "
@@ -105,7 +108,6 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
         ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
         ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
         ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
-        ir.DataType.FLOAT8E8M0: torch.float8_e8m0fnu,
         ir.DataType.INT16: torch.int16,
         ir.DataType.INT32: torch.int32,
         ir.DataType.INT64: torch.int64,
@@ -115,7 +117,17 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
         ir.DataType.UINT32: torch.uint32,
         ir.DataType.UINT64: torch.uint64,
     }
+
+    if hasattr(torch, "float8_e8m0fnu"):
+        # torch.float8_e8m0fnu is available in PyTorch 2.7+
+        _ONNX_DTYPE_TO_TORCH[ir.DataType.FLOAT8E8M0] = torch.float8_e8m0fnu
+
     if dtype not in _ONNX_DTYPE_TO_TORCH:
+        if dtype == ir.DataType.FLOAT8E8M0:
+            raise ValueError(
+                "The requested DataType 'FLOAT8E8M0' is only supported in PyTorch 2.7+. "
+                "Please upgrade your PyTorch version to use this dtype."
+            )
         raise TypeError(
             f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
             "Please use a supported dtype from the list: "

src/onnx_ir/tensor_adapters_test.py

Lines changed: 61 additions & 0 deletions
@@ -12,6 +12,7 @@
 import parameterized
 import torch
 
+import onnx_ir as ir
 from onnx_ir import tensor_adapters
 
 
@@ -83,5 +84,65 @@ def test_tobytes(self, dtype: torch.dtype):
         self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes())
 
 
+class TorchDtypeConversionTest(unittest.TestCase):
+    @parameterized.parameterized.expand(
+        [
+            (ir.DataType.BFLOAT16, torch.bfloat16),
+            (ir.DataType.BOOL, torch.bool),
+            (ir.DataType.COMPLEX128, torch.complex128),
+            (ir.DataType.COMPLEX64, torch.complex64),
+            (ir.DataType.FLOAT16, torch.float16),
+            (ir.DataType.FLOAT, torch.float32),
+            (ir.DataType.DOUBLE, torch.float64),
+            (ir.DataType.FLOAT8E4M3FN, torch.float8_e4m3fn),
+            (ir.DataType.FLOAT8E4M3FNUZ, torch.float8_e4m3fnuz),
+            (ir.DataType.FLOAT8E5M2, torch.float8_e5m2),
+            (ir.DataType.FLOAT8E5M2FNUZ, torch.float8_e5m2fnuz),
+            (ir.DataType.FLOAT8E8M0, torch.float8_e8m0fnu),  # Requires PyTorch 2.7+
+            (ir.DataType.INT16, torch.int16),
+            (ir.DataType.INT32, torch.int32),
+            (ir.DataType.INT64, torch.int64),
+            (ir.DataType.INT8, torch.int8),
+            (ir.DataType.UINT8, torch.uint8),
+            (ir.DataType.UINT16, torch.uint16),
+            (ir.DataType.UINT32, torch.uint32),
+            (ir.DataType.UINT64, torch.uint64),
+        ]
+    )
+    def test_to_torch_dtype(self, onnx_dtype: ir.DataType, expected_torch_dtype: torch.dtype):
+        actual = tensor_adapters.to_torch_dtype(onnx_dtype)
+        self.assertEqual(actual, expected_torch_dtype)
+
+    @parameterized.parameterized.expand(
+        [
+            (torch.bfloat16, ir.DataType.BFLOAT16),
+            (torch.bool, ir.DataType.BOOL),
+            (torch.complex128, ir.DataType.COMPLEX128),
+            (torch.complex64, ir.DataType.COMPLEX64),
+            (torch.float16, ir.DataType.FLOAT16),
+            (torch.float32, ir.DataType.FLOAT),
+            (torch.float64, ir.DataType.DOUBLE),
+            (torch.float8_e4m3fn, ir.DataType.FLOAT8E4M3FN),
+            (torch.float8_e4m3fnuz, ir.DataType.FLOAT8E4M3FNUZ),
+            (torch.float8_e5m2, ir.DataType.FLOAT8E5M2),
+            (torch.float8_e5m2fnuz, ir.DataType.FLOAT8E5M2FNUZ),
+            (torch.float8_e8m0fnu, ir.DataType.FLOAT8E8M0),  # Requires PyTorch 2.7+
+            (torch.int16, ir.DataType.INT16),
+            (torch.int32, ir.DataType.INT32),
+            (torch.int64, ir.DataType.INT64),
+            (torch.int8, ir.DataType.INT8),
+            (torch.uint8, ir.DataType.UINT8),
+            (torch.uint16, ir.DataType.UINT16),
+            (torch.uint32, ir.DataType.UINT32),
+            (torch.uint64, ir.DataType.UINT64),
+        ]
+    )
+    def test_from_torch_dtype(
+        self, torch_dtype: torch.dtype, expected_onnx_dtype: ir.DataType
+    ):
+        actual = tensor_adapters.from_torch_dtype(torch_dtype)
+        self.assertEqual(actual, expected_onnx_dtype)
+
+
 if __name__ == "__main__":
     unittest.main()
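
Note that both parameter lists reference `torch.float8_e8m0fnu` when the class body executes, so merely importing this test module on torch < 2.7 raises AttributeError (hence the `# Requires PyTorch 2.7+` comments). A sketch of one way to guard such cases, not part of this commit, is to assemble the version-dependent entries conditionally:

```python
import torch

# Sketch: include the e8m0 case only when the running torch exposes the dtype,
# so the module still imports (and the remaining cases still run) on torch < 2.7.
_E8M0_CASES = (
    [(torch.float8_e8m0fnu, "FLOAT8E8M0")]
    if hasattr(torch, "float8_e8m0fnu")
    else []
)
print(f"Collected {len(_E8M0_CASES)} float8_e8m0fnu test case(s)")
```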
