Skip to content

Commit 2aa455d

Browse files
author
Vincent Moens
committed
[BugFix] Consolidate lazy stacks of non-tensors
ghstack-source-id: afb1480 Pull Request resolved: #1222 (cherry picked from commit 0b901a7)
1 parent 7ff9d90 commit 2aa455d

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

tensordict/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5042,7 +5042,8 @@ def assign(
50425042
cls = type(value)
50435043
if issubclass(cls, torch.Tensor):
50445044
pass
5045-
elif _is_non_tensor(cls):
5045+
# We want to skip NonTensorStacks
5046+
elif _is_non_tensor(cls) and not issubclass(cls, TensorDictBase):
50465047
if requires_metadata:
50475048
metadata_dict["non_tensors"][key] = (
50485049
value.data,
@@ -5410,7 +5411,8 @@ def _view_and_pad(tensor):
54105411
if non_blocking and device.type != "cuda":
54115412
# sync if needed
54125413
self._sync_all()
5413-
torch.cat(items, out=storage)
5414+
if items:
5415+
torch.cat(items, out=storage)
54145416
for v, (k, oldv) in _zip_strict(
54155417
storage.split(flat_size), list(flat_dict.items())
54165418
):

test/test_tensordict.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11335,6 +11335,17 @@ def test_stack(self, non_tensor_data):
1133511335
LazyStackedTensorDict,
1133611336
)
1133711337

11338+
def test_stack_consolidate(self):
11339+
td = torch.stack(
11340+
[
11341+
TensorDict(a="a string", b="b string"),
11342+
TensorDict(a="another string", b="bnother string"),
11343+
]
11344+
)
11345+
tdc = td.consolidate()
11346+
assert (tdc == td).all()
11347+
assert tdc["a"] == ["a string", "another string"]
11348+
1133811349
def test_assign_non_tensor(self):
1133911350
data = TensorDict({}, [1, 10])
1134011351

0 commit comments

Comments
 (0)