Skip to content

Commit eb4a56e

Browse files
author
Vincent Moens
committed
[BugFix] Better comparison of tensorclasses
ghstack-source-id: 8def6f0 Pull Request resolved: #1137
1 parent d7506c4 commit eb4a56e

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

tensordict/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,11 @@ def assert_close(
15481548

15491549
from tensordict._lazy import LazyStackedTensorDict
15501550

1551+
if is_tensorclass(actual):
1552+
actual = actual._tensordict
1553+
if is_tensorclass(expected):
1554+
expected = expected._tensordict
1555+
15511556
if isinstance(actual, LazyStackedTensorDict) and isinstance(
15521557
expected, LazyStackedTensorDict
15531558
):

0 commit comments

Comments
 (0)