Skip to content

Commit 7579756

Browse files
author
Tim Joseph
committed
fix(tensor_dict): use get_args for TDCompatible type checking
In Python 3.8+, `isinstance` with a `typing.Union` like `TDCompatible` raises a `TypeError`. The correct way to check against a type alias that is a union is to use `get_args` from the `typing` module. This change updates the type check to `isinstance(value, get_args(TDCompatible))` to ensure correct runtime behavior across supported Python versions.
1 parent 18c5668 commit 7579756

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/tensorcontainer/tensor_dict.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Union,
2828
cast,
2929
overload,
30+
get_args
3031
)
3132

3233
import torch
@@ -177,7 +178,7 @@ def _pytree_flatten(
177178
metadata: Dict[str, Any] = {}
178179

179180
for key, value in self.data.items():
180-
if isinstance(value, TDCompatible):
181+
if isinstance(value, get_args(TDCompatible)):
181182
td_compatible_leaves.append(value)
182183
td_compatible_keys.append(key)
183184
else:

0 commit comments

Comments
 (0)