Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions onnxscript/type_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,13 @@ def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeTy
)


def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue:
"""Remove Annotated wrapper if present, otherwise return typeinfo as is."""
if hasattr(typing, "Annotated"):
# Present in Python 3.9+
if typing.get_origin(typeinfo) is typing.Annotated:
return typing.get_args(typeinfo)[0]
return typeinfo


def _is_primitive_attr_type(typeinfo: TypeAnnotationValue) -> bool:
return typeinfo in _PYTYPE_TO_ATTRTYPE_MAP


def pytype_to_attrtype(
pytype: TypeAnnotationValue,
) -> Optional[onnx.AttributeProto.AttributeType]:
pytype = _remove_annotation(pytype)
if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
return _PYTYPE_TO_ATTRTYPE_MAP[pytype]
type_constructor = typing.get_origin(pytype)
Expand All @@ -117,7 +107,6 @@ def pytype_to_attrtype(

def base_type_is_bool(pytype: TypeAnnotationValue) -> bool:
"""Returns True if base type of pytype is bool, False otherwise."""
pytype = _remove_annotation(pytype)
if pytype in _PYTYPE_TO_ATTRTYPE_MAP:
return pytype is bool
type_constructor = typing.get_origin(pytype)
Expand Down
4 changes: 2 additions & 2 deletions opgen/onnx_opset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pathlib import Path
from textwrap import dedent
from typing import Annotated, Any, Iterable, Optional, Set, TextIO
from typing import Any, Iterable, Optional, Set, TextIO

import onnx
import pygen as cg
Expand All @@ -32,7 +32,7 @@
MODULE_ONNX_SCRIPT_VALUES = "onnxscript.values"


OpsetId = tuple[Annotated[str, "domain"], Annotated[int, "version"]]
OpsetId = tuple[str, int]


def parse_opsetid(opsetid: str) -> OpsetId:
Expand Down