Skip to content

Commit 32264fc

Browse files
author
Vincent Moens
committed
[BugFix] Fix compile during _check_keys
ghstack-source-id: cfed094 Pull Request resolved: #1239 (cherry picked from commit 2ad9f95)
1 parent 63ef4fd commit 32264fc

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

tensordict/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,7 +1768,8 @@ def _check_keys(
17681768
is_leaf=_is_leaf_nontensor,
17691769
)
17701770
# TODO: compile doesn't like set() over an arbitrary object
1771-
if is_compiling():
1771+
is_comp = is_compiling()
1772+
if is_comp:
17721773
keys_set = {k for k in keys} # noqa: C416
17731774
else:
17741775
keys_set: set[str] = set(keys)
@@ -1781,7 +1782,7 @@ def _check_keys(
17811782
if not strict:
17821783
keys_set = keys_set.intersection(k)
17831784
else:
1784-
if is_compiling():
1785+
if is_comp:
17851786
k = {v for v in k} # noqa: C416
17861787
else:
17871788
k = set(k)
@@ -1790,7 +1791,10 @@ def _check_keys(
17901791
f"got keys {keys} and {set(td.keys())} which are incompatible"
17911792
)
17921793
if strict:
1793-
return list(keys)
1794+
if is_comp:
1795+
return [key for key in keys] # noqa: C416
1796+
else:
1797+
return list(keys)
17941798
return keys_set
17951799

17961800

0 commit comments

Comments
 (0)