Skip to content

Commit bfab676

Browse files
authored
Add more properties for numeric types to DataType (#171)
Fix #109 --------- Signed-off-by: Christoph Berganski <christoph.berganski@gmail.com>
1 parent c339165 commit bfab676

File tree

1 file changed

+141
-0
lines changed

1 file changed

+141
-0
lines changed

src/onnx_ir/_enums.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import enum
8+
from typing import Any
89

910
import ml_dtypes
1011
import numpy as np
@@ -131,6 +132,146 @@ def bitwidth(self) -> int:
131132
raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
132133
return _BITWIDTH_MAP[self]
133134

135+
@property
136+
def exponent_bitwidth(self) -> int:
137+
"""Returns the bit width of the exponent for floating-point types.
138+
139+
.. versionadded:: 0.1.8
140+
141+
Raises:
142+
TypeError: If the data type is not supported.
143+
"""
144+
if self.is_floating_point():
145+
return ml_dtypes.finfo(self.numpy()).nexp
146+
147+
raise TypeError(f"Exponent not available for ONNX data type: {self}")
148+
149+
@property
150+
def mantissa_bitwidth(self) -> int:
151+
"""Returns the bit width of the mantissa for floating-point types.
152+
153+
.. versionadded:: 0.1.8
154+
155+
Raises:
156+
TypeError: If the data type is not supported.
157+
"""
158+
if self.is_floating_point():
159+
return ml_dtypes.finfo(self.numpy()).nmant
160+
161+
raise TypeError(f"Mantissa not available for ONNX data type: {self}")
162+
163+
@property
164+
def eps(self) -> int | np.floating[Any]:
165+
"""Returns the difference between 1.0 and the next smallest representable float larger than 1.0 for the ONNX data type.
166+
167+
Returns 1 for integers.
168+
169+
.. versionadded:: 0.1.8
170+
171+
Raises:
172+
TypeError: If the data type is not a numeric data type.
173+
"""
174+
if self.is_integer():
175+
return 1
176+
177+
if self.is_floating_point():
178+
return ml_dtypes.finfo(self.numpy()).eps
179+
180+
raise TypeError(f"Eps not available for ONNX data type: {self}")
181+
182+
@property
183+
def tiny(self) -> int | np.floating[Any]:
184+
"""Returns the smallest positive non-zero value for the ONNX data type.
185+
186+
Returns 1 for integers.
187+
188+
.. versionadded:: 0.1.8
189+
190+
Raises:
191+
TypeError: If the data type is not a numeric data type.
192+
"""
193+
if self.is_integer():
194+
return 1
195+
196+
if self.is_floating_point():
197+
return ml_dtypes.finfo(self.numpy()).tiny
198+
199+
raise TypeError(f"Tiny not available for ONNX data type: {self}")
200+
201+
@property
202+
def min(self) -> int | np.floating[Any]:
203+
"""Returns the minimum representable value for the ONNX data type.
204+
205+
.. versionadded:: 0.1.8
206+
207+
Raises:
208+
TypeError: If the data type is not a numeric data type.
209+
"""
210+
if self.is_integer():
211+
return ml_dtypes.iinfo(self.numpy()).min
212+
213+
if self.is_floating_point():
214+
return ml_dtypes.finfo(self.numpy()).min
215+
216+
raise TypeError(f"Minimum not available for ONNX data type: {self}")
217+
218+
@property
219+
def max(self) -> int | np.floating[Any]:
220+
"""Returns the maximum representable value for the ONNX data type.
221+
222+
.. versionadded:: 0.1.8
223+
224+
Raises:
225+
TypeError: If the data type is not a numeric data type.
226+
"""
227+
if self.is_integer():
228+
return ml_dtypes.iinfo(self.numpy()).max
229+
230+
if self.is_floating_point():
231+
return ml_dtypes.finfo(self.numpy()).max
232+
233+
raise TypeError(f"Maximum not available for ONNX data type: {self}")
234+
235+
@property
236+
def precision(self) -> int:
237+
"""Returns the precision for the ONNX dtype if supported.
238+
239+
For floats returns the approximate number of decimal digits to which
240+
this kind of float is precise. Returns 0 for integers.
241+
242+
.. versionadded:: 0.1.8
243+
244+
Raises:
245+
TypeError: If the data type is not a numeric data type.
246+
"""
247+
if self.is_integer():
248+
return 0
249+
250+
if self.is_floating_point():
251+
return ml_dtypes.finfo(self.numpy()).precision
252+
253+
raise TypeError(f"Precision not available for ONNX data type: {self}")
254+
255+
@property
256+
def resolution(self) -> int | np.floating[Any]:
257+
"""Returns the resolution for the ONNX dtype if supported.
258+
259+
Returns the approximate decimal resolution of this type, i.e.,
260+
10**-precision. Returns 1 for integers.
261+
262+
.. versionadded:: 0.1.8
263+
264+
Raises:
265+
TypeError: If the data type is not a numeric data type.
266+
"""
267+
if self.is_integer():
268+
return 1
269+
270+
if self.is_floating_point():
271+
return ml_dtypes.finfo(self.numpy()).resolution
272+
273+
raise TypeError(f"Resolution not available for ONNX data type: {self}")
274+
134275
def numpy(self) -> np.dtype:
135276
"""Returns the numpy dtype for the ONNX data type.
136277

0 commit comments

Comments
 (0)