We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 488500c commit 554bf54Copy full SHA for 554bf54
tensordict/utils.py
@@ -2496,17 +2496,20 @@ def _is_non_tensor(cls: type):
2496
return out
2497
2498
2499
+_PASSTHROUGH_MEMO = {}
2500
+
2501
2502
def _pass_through_cls(cls: type):
2503
out = None
2504
is_dynamo = is_compiling()
2505
if not is_dynamo:
- out = _NON_TENSOR_MEMO.get(cls)
2506
+ out = _PASSTHROUGH_MEMO.get(cls)
2507
if out is None:
2508
out = bool(getattr(cls, "_is_non_tensor", False)) or getattr(
2509
cls, "_pass_through", False
2510
)
2511
- _NON_TENSOR_MEMO[cls] = out
2512
+ _PASSTHROUGH_MEMO[cls] = out
2513
2514
2515
0 commit comments