Skip to content

Commit 488500c

Browse files
author
Vincent Moens
committed
[BugFix] Fix non-deterministic key order in stack (#1230)
(cherry picked from commit c35d7aa)
1 parent d2d30f6 commit 488500c

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

tensordict/_torch_func.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,6 @@ def stack_fn(key, values, is_not_init, is_tensor):
626626
key: stack_fn(key, values, is_not_init, is_tensor)
627627
for key, (values, is_not_init, is_tensor) in out.items()
628628
}
629-
630629
result = clz._new_unsafe(
631630
out,
632631
batch_size=LazyStackedTensorDict._compute_batch_size(

tensordict/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,7 +1757,7 @@ def _check_keys(
17571757
strict: bool = False,
17581758
include_nested: bool = False,
17591759
leaves_only: bool = False,
1760-
) -> set[str]:
1760+
) -> set[str] | list[str]:
17611761
from tensordict.base import _is_leaf_nontensor
17621762

17631763
if not len(list_of_tensordicts):
@@ -1769,27 +1769,29 @@ def _check_keys(
17691769
)
17701770
# TODO: compile doesn't like set() over an arbitrary object
17711771
if is_compiling():
1772-
keys = {k for k in keys} # noqa: C416
1772+
keys_set = {k for k in keys} # noqa: C416
17731773
else:
1774-
keys: set[str] = set(keys)
1774+
keys_set: set[str] = set(keys)
17751775
for td in list_of_tensordicts[1:]:
17761776
k = td.keys(
17771777
include_nested=include_nested,
17781778
leaves_only=leaves_only,
17791779
is_leaf=_is_leaf_nontensor,
17801780
)
17811781
if not strict:
1782-
keys = keys.intersection(k)
1782+
keys_set = keys_set.intersection(k)
17831783
else:
17841784
if is_compiling():
17851785
k = {v for v in k} # noqa: C416
17861786
else:
17871787
k = set(k)
1788-
if k != keys:
1788+
if k != keys_set:
17891789
raise KeyError(
17901790
f"got keys {keys} and {set(td.keys())} which are incompatible"
17911791
)
1792-
return keys
1792+
if strict:
1793+
return list(keys)
1794+
return keys_set
17931795

17941796

17951797
def _set_max_batch_size(source: T, batch_dims=None):

test/test_tensorclass.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,6 +1884,20 @@ class MyDataNested:
18841884
):
18851885
torch.stack([data1, data3], dim=0)
18861886

1887+
def test_stack_keyorder(self):
1888+
1889+
class MyTensorClass(TensorClass):
1890+
foo: Tensor
1891+
bar: Tensor
1892+
1893+
tc1 = MyTensorClass(foo=torch.zeros((1,)), bar=torch.ones((1,)))
1894+
1895+
for _ in range(10000):
1896+
assert list(torch.stack([tc1, tc1], dim=0)._tensordict.keys()) == [
1897+
"foo",
1898+
"bar",
1899+
]
1900+
18871901
def test_statedict_errors(self):
18881902
@tensorclass
18891903
class MyClass:

0 commit comments

Comments
 (0)