Skip to content

Commit 3e0c2d8

Browse files
author
Vincent Moens
committed
[BugFix] Consolidate lazy stacks of non-tensors
ghstack-source-id: d3d822d Pull Request resolved: #1224 (cherry picked from commit f67a15c)
1 parent 2aa455d commit 3e0c2d8

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

tensordict/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
lazy_legacy,
9797
lock_blocked,
9898
prod,
99+
set_capture_non_tensor_stack,
99100
set_lazy_legacy,
100101
strtobool,
101102
TensorDictFuture,
@@ -9078,7 +9079,8 @@ def newfn(item_and_out):
90789079
from tensordict._lazy import LazyStackedTensorDict
90799080

90809081
# We want to be able to return whichever data structure
9081-
out = LazyStackedTensorDict.maybe_dense_stack(imaplist, dim)
9082+
with set_capture_non_tensor_stack(False):
9083+
out = LazyStackedTensorDict.maybe_dense_stack(imaplist, dim)
90829084
else:
90839085
out = torch.cat(imaplist, dim)
90849086
return out

test/test_tensordict.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939

4040
from tensordict import (
4141
get_defaults_to_none,
42+
lazy_legacy,
43+
lazy_stack,
4244
LazyStackedTensorDict,
4345
make_tensordict,
4446
PersistentTensorDict,
@@ -65,7 +67,6 @@
6567
convert_ellipsis_to_idx,
6668
is_non_tensor,
6769
is_tensorclass,
68-
lazy_legacy,
6970
logger as tdlogger,
7071
set_lazy_legacy,
7172
)
@@ -11190,6 +11191,11 @@ def test_map_iter_interrupt_early(self, chunksize, num_chunks, shuffle):
1119011191

1119111192

1119211193
class TestNonTensorData:
11194+
@tensorclass
11195+
class SomeTensorClass:
11196+
a: str
11197+
b: torch.Tensor
11198+
1119311199
@pytest.fixture
1119411200
def non_tensor_data(self):
1119511201
return TensorDict(
@@ -11204,6 +11210,27 @@ def non_tensor_data(self):
1120411210
batch_size=[],
1120511211
)
1120611212

11213+
@set_capture_non_tensor_stack(False)
11214+
def test_consolidate_nested(self):
11215+
import pickle
11216+
11217+
td = TensorDict(
11218+
a=TensorDict(b=self.SomeTensorClass(a="a string!", b=torch.randn(10))),
11219+
c=TensorDict(d=NonTensorData("another string!")),
11220+
)
11221+
td = lazy_stack([td.clone(), td.clone()])
11222+
td = lazy_stack([td.clone(), td.clone()], -1)
11223+
11224+
tdc = td.consolidate()
11225+
11226+
assert (tdc == td).all()
11227+
11228+
tdr = pickle.loads(pickle.dumps(td))
11229+
assert (tdr == td).all()
11230+
11231+
tdcr = pickle.loads(pickle.dumps(tdc))
11232+
assert (tdcr == td).all()
11233+
1120711234
def test_comparison(self, non_tensor_data):
1120811235
non_tensor_data = non_tensor_data.exclude(("nested", "str"))
1120911236
assert (non_tensor_data | non_tensor_data).get_non_tensor(("nested", "bool"))

0 commit comments

Comments
 (0)