Skip to content

Commit ee585a3

Browse files
author
Vincent Moens
committed
[BugFix] Fix serialization of stacks of Tensorclasses
ghstack-source-id: 8e47f46 Pull Request resolved: #1236 (cherry picked from commit 635c9c0)
1 parent d4e27d5 commit ee585a3

File tree

4 files changed

+29
-3
lines changed

4 files changed

+29
-3
lines changed

tensordict/_lazy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def from_dict(
381381
stack_dim_name=None,
382382
stack_dim=0,
383383
):
384-
return LazyStackedTensorDict(
384+
return cls._new_lazy_unsafe(
385385
*(
386386
TensorDict.from_dict(
387387
input_dict[str(i)],

tensordict/_reductions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,12 @@ def from_metadata(metadata=metadata, prefix=None):
9191
_ = metadata.pop("size", None)
9292

9393
d = {
94-
key: NonTensorData(data, batch_size=batch_size)
95-
for (key, (data, batch_size)) in non_tensor.items()
94+
key: NonTensorData(
95+
data,
96+
batch_size=batch_size,
97+
device=torch.device(device) if device is not None else None,
98+
)
99+
for (key, (data, batch_size, device)) in non_tensor.items()
96100
}
97101
for key, (dtype, local_shape, start, stop, pad) in leaves.items():
98102
dtype = _STRDTYPE2DTYPE[dtype]

tensordict/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5055,6 +5055,7 @@ def assign(
50555055
metadata_dict["non_tensors"][key] = (
50565056
value.data,
50575057
list(value.batch_size),
5058+
str(value.device) if value.device is not None else None,
50585059
)
50595060
return
50605061
elif _is_tensor_collection(cls):

test/test_tensorclass.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ class MyDataClass:
233233
MyTensorClass_autocast = MyTensorClass_nocast = MyTensorClass = None
234234

235235

236+
@tensorclass
237+
class TCStrings:
238+
a: str
239+
b: str
240+
241+
236242
class TestTensorClass:
237243
def test_get_default(self):
238244
@tensorclass
@@ -1250,6 +1256,21 @@ def test_pickle(self):
12501256
assert isinstance(data2, MyData)
12511257
assert data2.z == data.z
12521258

1259+
@pytest.mark.parametrize("consolidate", [False, True])
1260+
def test_pickle_consolidate(self, consolidate):
1261+
with set_capture_non_tensor_stack(False):
1262+
1263+
tc = TCStrings(a="a", b="b")
1264+
1265+
tcstack = TensorDict(tc=torch.stack([tc, tc.clone()]))
1266+
if consolidate:
1267+
tcstack = tcstack.consolidate()
1268+
assert isinstance(tcstack["tc"], TCStrings)
1269+
loaded = pickle.loads(pickle.dumps(tcstack))
1270+
assert isinstance(loaded["tc"], TCStrings)
1271+
assert loaded["tc"].a == tcstack["tc"].a
1272+
assert loaded["tc"].b == tcstack["tc"].b
1273+
12531274
def test_post_init(self):
12541275
@tensorclass
12551276
class MyDataPostInit:

0 commit comments

Comments
 (0)