From d6eccaa50475cdcf06175c92b38419efc04258b6 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Sun, 26 Oct 2025 19:38:51 +0530 Subject: [PATCH 1/4] Fix attention mask to use float_lowest instead of -inf and add unit test for softmax NaN case --- onnxscript/function_libs/torch_lib/ops/nn.py | 6 +++++- tests/common/testutils.py | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 4f81cc7907..65bb2aa079 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -14,6 +14,7 @@ from __future__ import annotations +import numpy as np import math from typing import Optional, Sequence, Tuple, TypeVar, Union @@ -2048,6 +2049,9 @@ def _aten_scaled_dot_product_attention_no_mask_onnx( attn_weight, _ = op.Dropout(attn_weight, dropout_p) return op.MatMul(attn_weight, value) +def float_lowest(dtype): + """Returns the lowest representable value for the given numpy dtype.""" + return np.finfo(np.dtype(dtype)).min def _aten_scaled_dot_product_attention_bool_mask_onnx( query: TFloat, @@ -2078,7 +2082,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx( key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale)) # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype)) - neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype)) + neg_inf = op.Constant(value=ir.tensor(float_lowest(query.dtype)), dtype=query.dtype) attn_mask = op.Where(attn_mask, zero, neg_inf) attn_weight = op.Softmax( op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), diff --git a/tests/common/testutils.py b/tests/common/testutils.py index 2a2697b240..1db673eab8 100644 --- a/tests/common/testutils.py +++ b/tests/common/testutils.py @@ -14,6 +14,7 @@ import torch from onnxscript import optimizer +from onnxscript.onnx_opset import opset18 as op from onnxscript.rewriter import onnxruntime as ort_rewriter from onnxscript.utils import evaluation_utils @@ -101,3 +102,9 @@ def test_onnxruntime_rewrite( f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}" ) raise + +def test_softmax_with_all_inf_mask(): + # GH #2561 + input = np.array([[-float("inf"), -float("inf")]], dtype=np.float32) + output = op.Softmax(input, axis=-1) + assert np.isnan(output).all(), "Softmax should return NaN when all inputs are -inf" From 0d7c411d1e381322178f0e5fe252eecd54d4d25a Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Thu, 13 Nov 2025 20:25:40 +0530 Subject: [PATCH 2/4] Remove helper function and test --- onnxscript/function_libs/torch_lib/ops/nn.py | 6 +----- tests/common/testutils.py | 7 ------- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 65bb2aa079..6cce402ddf 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -14,7 +14,6 @@ from __future__ import annotations -import numpy as np import math from typing import Optional, Sequence, Tuple, TypeVar, Union @@ -2049,9 +2048,6 @@ def _aten_scaled_dot_product_attention_no_mask_onnx( attn_weight, _ = op.Dropout(attn_weight, dropout_p) return op.MatMul(attn_weight, value) -def float_lowest(dtype): - """Returns the lowest representable value for the given numpy dtype.""" - return np.finfo(np.dtype(dtype)).min def _aten_scaled_dot_product_attention_bool_mask_onnx( query: TFloat, @@ -2082,7 +2078,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx( key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale)) # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype)) - neg_inf = op.Constant(value=ir.tensor(float_lowest(query.dtype)), dtype=query.dtype) + neg_inf = op.Constant(value=ir.tensor(query.dtype.min), dtype=query.dtype) attn_mask = op.Where(attn_mask, zero, neg_inf) attn_weight = op.Softmax( op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask), diff --git a/tests/common/testutils.py b/tests/common/testutils.py index 1db673eab8..2a2697b240 100644 --- a/tests/common/testutils.py +++ b/tests/common/testutils.py @@ -14,7 +14,6 @@ import torch from onnxscript import optimizer -from onnxscript.onnx_opset import opset18 as op from onnxscript.rewriter import onnxruntime as ort_rewriter from onnxscript.utils import evaluation_utils @@ -102,9 +101,3 @@ def test_onnxruntime_rewrite( f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}" ) raise - -def test_softmax_with_all_inf_mask(): - # GH #2561 - input = np.array([[-float("inf"), -float("inf")]], dtype=np.float32) - output = op.Softmax(input, axis=-1) - assert np.isnan(output).all(), "Softmax should return NaN when all inputs are -inf" From d84309e96c94f920f402a632bc1add56d3c93423 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Fri, 14 Nov 2025 18:50:12 +0530 Subject: [PATCH 3/4] Remove usage of python typing for programmatic annotation --- onnxscript/type_annotation.py | 11 ----------- opgen/onnx_opset_builder.py | 4 ++-- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index fb7b8a370d..345a10309f 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -82,15 +82,6 @@ def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeTy ) -def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue: - """Remove Annotated wrapper if present, otherwise return typeinfo as is.""" - if hasattr(typing, "Annotated"): - # Present in Python 3.9+ - if typing.get_origin(typeinfo) is typing.Annotated: - return typing.get_args(typeinfo)[0] - return typeinfo - - def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP @@ -98,7 +89,6 @@ def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool: def pytype_to_attrtype( pytype: TypeAnnotationValue, ) -> Optional[onnx.AttributeProto.AttributeType]: - pytype = _remove_annotation(pytype) if pytype in _PYTYPE_TO_ATTRTYPE_MAP: return _PYTYPE_TO_ATTRTYPE_MAP[pytype] type_constructor = typing.get_origin(pytype) @@ -117,7 +107,6 @@ def pytype_to_attrtype( def base_type_is_bool(pytype: TypeAnnotationValue) -> bool: """Returns True if base type of pytype is bool, False otherwise.""" - pytype = _remove_annotation(pytype) if pytype in _PYTYPE_TO_ATTRTYPE_MAP: return pytype is bool type_constructor = typing.get_origin(pytype) diff --git a/opgen/onnx_opset_builder.py b/opgen/onnx_opset_builder.py index f5c3c0daab..d898376231 100644 --- a/opgen/onnx_opset_builder.py +++ b/opgen/onnx_opset_builder.py @@ -5,7 +5,7 @@ from pathlib import Path from textwrap import dedent -from typing import Annotated, Any, Iterable, Optional, Set, TextIO +from typing import Any, Iterable, Optional, Set, TextIO import onnx import pygen as cg @@ -32,7 +32,7 @@ MODULE_ONNX_SCRIPT_VALUES = "onnxscript.values" -OpsetId = tuple[Annotated[str, "domain"], Annotated[int, "version"]] +OpsetId = tuple[str, int] def parse_opsetid(opsetid: str) -> OpsetId: From 37b3fc165a86d66f2d18db5c038eaacdc5a633f4 Mon Sep 17 00:00:00 2001 From: Aniket Singh Yadav Date: Fri, 14 Nov 2025 18:56:11 +0530 Subject: [PATCH 4/4] Remove usage of python typing for programmatic annotation --- onnxscript/function_libs/torch_lib/ops/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 6cce402ddf..4f81cc7907 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2078,7 +2078,7 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx( key_transposed_scaled = op.Mul(key_transposed, op.Sqrt(scale)) # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) zero = op.Constant(value=ir.tensor(0.0, dtype=query.dtype)) - neg_inf = op.Constant(value=ir.tensor(query.dtype.min), dtype=query.dtype) + neg_inf = op.Constant(value=ir.tensor(-float("inf"), dtype=query.dtype)) attn_mask = op.Where(attn_mask, zero, neg_inf) attn_weight = op.Softmax( op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),