@@ -12,6 +12,7 @@
 import parameterized
 import torch
 
+import onnx_ir as ir
 from onnx_ir import tensor_adapters
 
 
@@ -83,5 +84,65 @@ def test_tobytes(self, dtype: torch.dtype):
         self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes())
 
 
+class TorchDtypeConversionTest(unittest.TestCase):
+    @parameterized.parameterized.expand(
+        [
+            (ir.DataType.BFLOAT16, torch.bfloat16),
+            (ir.DataType.BOOL, torch.bool),
+            (ir.DataType.COMPLEX128, torch.complex128),
+            (ir.DataType.COMPLEX64, torch.complex64),
+            (ir.DataType.FLOAT16, torch.float16),
+            (ir.DataType.FLOAT, torch.float32),
+            (ir.DataType.DOUBLE, torch.float64),
+            (ir.DataType.FLOAT8E4M3FN, torch.float8_e4m3fn),
+            (ir.DataType.FLOAT8E4M3FNUZ, torch.float8_e4m3fnuz),
+            (ir.DataType.FLOAT8E5M2, torch.float8_e5m2),
+            (ir.DataType.FLOAT8E5M2FNUZ, torch.float8_e5m2fnuz),
+            (ir.DataType.FLOAT8E8M0, torch.float8_e8m0fnu),  # Requires PyTorch 2.7+
+            (ir.DataType.INT16, torch.int16),
+            (ir.DataType.INT32, torch.int32),
+            (ir.DataType.INT64, torch.int64),
+            (ir.DataType.INT8, torch.int8),
+            (ir.DataType.UINT8, torch.uint8),
+            (ir.DataType.UINT16, torch.uint16),
+            (ir.DataType.UINT32, torch.uint32),
+            (ir.DataType.UINT64, torch.uint64),
+        ]
+    )
+    def test_to_torch_dtype(self, onnx_dtype: ir.DataType, expected_torch_dtype: torch.dtype):
+        actual = tensor_adapters.to_torch_dtype(onnx_dtype)
+        self.assertEqual(actual, expected_torch_dtype)
+
+    @parameterized.parameterized.expand(
+        [
+            (torch.bfloat16, ir.DataType.BFLOAT16),
+            (torch.bool, ir.DataType.BOOL),
+            (torch.complex128, ir.DataType.COMPLEX128),
+            (torch.complex64, ir.DataType.COMPLEX64),
+            (torch.float16, ir.DataType.FLOAT16),
+            (torch.float32, ir.DataType.FLOAT),
+            (torch.float64, ir.DataType.DOUBLE),
+            (torch.float8_e4m3fn, ir.DataType.FLOAT8E4M3FN),
+            (torch.float8_e4m3fnuz, ir.DataType.FLOAT8E4M3FNUZ),
+            (torch.float8_e5m2, ir.DataType.FLOAT8E5M2),
+            (torch.float8_e5m2fnuz, ir.DataType.FLOAT8E5M2FNUZ),
+            (torch.float8_e8m0fnu, ir.DataType.FLOAT8E8M0),  # Requires PyTorch 2.7+
+            (torch.int16, ir.DataType.INT16),
+            (torch.int32, ir.DataType.INT32),
+            (torch.int64, ir.DataType.INT64),
+            (torch.int8, ir.DataType.INT8),
+            (torch.uint8, ir.DataType.UINT8),
+            (torch.uint16, ir.DataType.UINT16),
+            (torch.uint32, ir.DataType.UINT32),
+            (torch.uint64, ir.DataType.UINT64),
+        ]
+    )
+    def test_from_torch_dtype(
+        self, torch_dtype: torch.dtype, expected_onnx_dtype: ir.DataType
+    ):
+        actual = tensor_adapters.from_torch_dtype(torch_dtype)
+        self.assertEqual(actual, expected_onnx_dtype)
+
+
 if __name__ == "__main__":
     unittest.main()
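
For reference, a minimal sketch of the round-trip behavior these tests pin down, using only the two converters and the dtype pairs exercised in the diff above:

import torch
import onnx_ir as ir
from onnx_ir import tensor_adapters

# ONNX -> torch: per the first parameterized table, ir.DataType.FLOAT
# maps to torch.float32.
torch_dtype = tensor_adapters.to_torch_dtype(ir.DataType.FLOAT)
assert torch_dtype == torch.float32

# torch -> ONNX: from_torch_dtype inverts the mapping, so every pair in
# the tables round-trips (the float8_e8m0fnu case needs PyTorch 2.7+).
onnx_dtype = tensor_adapters.from_torch_dtype(torch_dtype)
assert onnx_dtype == ir.DataType.FLOAT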