Skip to content

Commit 554bf54

Browse files
author
Vincent Moens
committed
[BugFix] _PASSTHROUGH_MEMO for passthrough tensorclass
ghstack-source-id: 0bfbfc9 Pull Request resolved: #1231 (cherry picked from commit d25bd54)
1 parent 488500c commit 554bf54

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tensordict/utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2496,17 +2496,20 @@ def _is_non_tensor(cls: type):
24962496
return out
24972497

24982498

2499+
_PASSTHROUGH_MEMO = {}
2500+
2501+
24992502
def _pass_through_cls(cls: type):
25002503
out = None
25012504
is_dynamo = is_compiling()
25022505
if not is_dynamo:
2503-
out = _NON_TENSOR_MEMO.get(cls)
2506+
out = _PASSTHROUGH_MEMO.get(cls)
25042507
if out is None:
25052508
out = bool(getattr(cls, "_is_non_tensor", False)) or getattr(
25062509
cls, "_pass_through", False
25072510
)
25082511
if not is_dynamo:
2509-
_NON_TENSOR_MEMO[cls] = out
2512+
_PASSTHROUGH_MEMO[cls] = out
25102513
return out
25112514

25122515

0 commit comments

Comments
 (0)