1212import math
1313import sys
1414import warnings
15- from collections .abc import Hashable
15+ from collections .abc import Collection , Hashable
1616from functools import lru_cache
17- from types import NoneType
1817from typing import (
1918 TYPE_CHECKING ,
2019 Any ,
2120 Final ,
2221 Literal ,
22+ SupportsIndex ,
2323 TypeAlias ,
2424 TypeGuard ,
2525 cast ,
5151 | ndx .Array
5252 | sparse .SparseArray
5353 | torch .Tensor
54- | SupportsArrayNamespace
54+ | SupportsArrayNamespace [ Any ]
5555 )
5656
5757_API_VERSIONS_OLD : Final = frozenset ({"2021.12" , "2022.12" , "2023.12" })
@@ -630,9 +630,9 @@ def your_function(x, y):
630630 raise ValueError (
631631 "The given array does not have an array-api-compat wrapper"
632632 )
633- x = cast (SupportsArrayNamespace , x )
633+ x = cast (" SupportsArrayNamespace[Any]" , x )
634634 namespaces .add (x .__array_namespace__ (api_version = api_version ))
635- elif isinstance (x , int | float | complex | NoneType ) :
635+ elif isinstance (x , int | float | complex ) or x is None :
636636 continue
637637 else :
638638 # TODO: Support Python scalars?
@@ -890,12 +890,10 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
890890
891891
892892@overload
893- def size (x : HasShape [int ]) -> int : ...
893+ def size (x : HasShape [Collection [ SupportsIndex ] ]) -> int : ...
894894@overload
895- def size (x : HasShape [int | None ]) -> int | None : ...
896- @overload
897- def size (x : HasShape [float ]) -> int | None : ... # Dask special case
898- def size (x : HasShape [float | None ]) -> int | None :
895+ def size (x : HasShape [Collection [SupportsIndex | None ]]) -> int | None : ...
896+ def size (x : HasShape [Collection [SupportsIndex | None ]]) -> int | None :
899897 """
900898 Return the total number of elements of x.
901899
@@ -910,9 +908,9 @@ def size(x: HasShape[float | None]) -> int | None:
910908 # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
911909 if None in x .shape :
912910 return None
913- out = math .prod (cast (tuple [ float , ...] , x .shape ))
911+ out = math .prod (cast ("Collection[SupportsIndex]" , x .shape ))
914912 # dask.array.Array.shape can contain NaN
915- return None if math .isnan (out ) else cast ( int , out )
913+ return None if math .isnan (out ) else out
916914
917915
918916@lru_cache (100 )
@@ -1003,7 +1001,7 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
10031001 # on __bool__ (dask is one such example, which however is special-cased above).
10041002
10051003 # Select a single point of the array
1006- s = size (cast (HasShape , x ))
1004+ s = size (cast (" HasShape[Collection[SupportsIndex | None]]" , x ))
10071005 if s is None :
10081006 return True
10091007 xp = array_namespace (x )
0 commit comments