|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import enum |
| 8 | +from typing import Any |
8 | 9 |
|
9 | 10 | import ml_dtypes |
10 | 11 | import numpy as np |
@@ -131,6 +132,146 @@ def bitwidth(self) -> int: |
131 | 132 | raise TypeError(f"Bitwidth not available for ONNX data type: {self}") |
132 | 133 | return _BITWIDTH_MAP[self] |
133 | 134 |
|
| 135 | + @property |
| 136 | + def exponent_bitwidth(self) -> int: |
| 137 | + """Returns the bit width of the exponent for floating-point types. |
| 138 | +
|
| 139 | + .. versionadded:: 0.1.8 |
| 140 | +
|
| 141 | + Raises: |
| 142 | + TypeError: If the data type is not supported. |
| 143 | + """ |
| 144 | + if self.is_floating_point(): |
| 145 | + return ml_dtypes.finfo(self.numpy()).nexp |
| 146 | + |
| 147 | + raise TypeError(f"Exponent not available for ONNX data type: {self}") |
| 148 | + |
| 149 | + @property |
| 150 | + def mantissa_bitwidth(self) -> int: |
| 151 | + """Returns the bit width of the mantissa for floating-point types. |
| 152 | +
|
| 153 | + .. versionadded:: 0.1.8 |
| 154 | +
|
| 155 | + Raises: |
| 156 | + TypeError: If the data type is not supported. |
| 157 | + """ |
| 158 | + if self.is_floating_point(): |
| 159 | + return ml_dtypes.finfo(self.numpy()).nmant |
| 160 | + |
| 161 | + raise TypeError(f"Mantissa not available for ONNX data type: {self}") |
| 162 | + |
| 163 | + @property |
| 164 | + def eps(self) -> int | np.floating[Any]: |
| 165 | + """Returns the difference between 1.0 and the next smallest representable float larger than 1.0 for the ONNX data type. |
| 166 | +
|
| 167 | + Returns 1 for integers. |
| 168 | +
|
| 169 | + .. versionadded:: 0.1.8 |
| 170 | +
|
| 171 | + Raises: |
| 172 | + TypeError: If the data type is not a numeric data type. |
| 173 | + """ |
| 174 | + if self.is_integer(): |
| 175 | + return 1 |
| 176 | + |
| 177 | + if self.is_floating_point(): |
| 178 | + return ml_dtypes.finfo(self.numpy()).eps |
| 179 | + |
| 180 | + raise TypeError(f"Eps not available for ONNX data type: {self}") |
| 181 | + |
| 182 | + @property |
| 183 | + def tiny(self) -> int | np.floating[Any]: |
| 184 | + """Returns the smallest positive non-zero value for the ONNX data type. |
| 185 | +
|
| 186 | + Returns 1 for integers. |
| 187 | +
|
| 188 | + .. versionadded:: 0.1.8 |
| 189 | +
|
| 190 | + Raises: |
| 191 | + TypeError: If the data type is not a numeric data type. |
| 192 | + """ |
| 193 | + if self.is_integer(): |
| 194 | + return 1 |
| 195 | + |
| 196 | + if self.is_floating_point(): |
| 197 | + return ml_dtypes.finfo(self.numpy()).tiny |
| 198 | + |
| 199 | + raise TypeError(f"Tiny not available for ONNX data type: {self}") |
| 200 | + |
| 201 | + @property |
| 202 | + def min(self) -> int | np.floating[Any]: |
| 203 | + """Returns the minimum representable value for the ONNX data type. |
| 204 | +
|
| 205 | + .. versionadded:: 0.1.8 |
| 206 | +
|
| 207 | + Raises: |
| 208 | + TypeError: If the data type is not a numeric data type. |
| 209 | + """ |
| 210 | + if self.is_integer(): |
| 211 | + return ml_dtypes.iinfo(self.numpy()).min |
| 212 | + |
| 213 | + if self.is_floating_point(): |
| 214 | + return ml_dtypes.finfo(self.numpy()).min |
| 215 | + |
| 216 | + raise TypeError(f"Minimum not available for ONNX data type: {self}") |
| 217 | + |
| 218 | + @property |
| 219 | + def max(self) -> int | np.floating[Any]: |
| 220 | + """Returns the maximum representable value for the ONNX data type. |
| 221 | +
|
| 222 | + .. versionadded:: 0.1.8 |
| 223 | +
|
| 224 | + Raises: |
| 225 | + TypeError: If the data type is not a numeric data type. |
| 226 | + """ |
| 227 | + if self.is_integer(): |
| 228 | + return ml_dtypes.iinfo(self.numpy()).max |
| 229 | + |
| 230 | + if self.is_floating_point(): |
| 231 | + return ml_dtypes.finfo(self.numpy()).max |
| 232 | + |
| 233 | + raise TypeError(f"Maximum not available for ONNX data type: {self}") |
| 234 | + |
| 235 | + @property |
| 236 | + def precision(self) -> int: |
| 237 | + """Returns the precision for the ONNX dtype if supported. |
| 238 | +
|
| 239 | + For floats returns the approximate number of decimal digits to which |
| 240 | + this kind of float is precise. Returns 0 for integers. |
| 241 | +
|
| 242 | + .. versionadded:: 0.1.8 |
| 243 | +
|
| 244 | + Raises: |
| 245 | + TypeError: If the data type is not a numeric data type. |
| 246 | + """ |
| 247 | + if self.is_integer(): |
| 248 | + return 0 |
| 249 | + |
| 250 | + if self.is_floating_point(): |
| 251 | + return ml_dtypes.finfo(self.numpy()).precision |
| 252 | + |
| 253 | + raise TypeError(f"Precision not available for ONNX data type: {self}") |
| 254 | + |
| 255 | + @property |
| 256 | + def resolution(self) -> int | np.floating[Any]: |
| 257 | + """Returns the resolution for the ONNX dtype if supported. |
| 258 | +
|
| 259 | + Returns the approximate decimal resolution of this type, i.e., |
| 260 | + 10**-precision. Returns 1 for integers. |
| 261 | +
|
| 262 | + .. versionadded:: 0.1.8 |
| 263 | +
|
| 264 | + Raises: |
| 265 | + TypeError: If the data type is not a numeric data type. |
| 266 | + """ |
| 267 | + if self.is_integer(): |
| 268 | + return 1 |
| 269 | + |
| 270 | + if self.is_floating_point(): |
| 271 | + return ml_dtypes.finfo(self.numpy()).resolution |
| 272 | + |
| 273 | + raise TypeError(f"Resolution not available for ONNX data type: {self}") |
| 274 | + |
134 | 275 | def numpy(self) -> np.dtype: |
135 | 276 | """Returns the numpy dtype for the ONNX data type. |
136 | 277 |
|
|
0 commit comments