File tree Expand file tree Collapse file tree 2 files changed +15
-2
lines changed Expand file tree Collapse file tree 2 files changed +15
-2
lines changed Original file line number Diff line number Diff line change @@ -5042,7 +5042,8 @@ def assign(
50425042 cls = type(value)
50435043 if issubclass(cls, torch.Tensor):
50445044 pass
5045- elif _is_non_tensor(cls):
5045+ # We want to skip NonTensorStacks
5046+ elif _is_non_tensor(cls) and not issubclass(cls, TensorDictBase):
50465047 if requires_metadata:
50475048 metadata_dict["non_tensors"][key] = (
50485049 value.data,
@@ -5410,7 +5411,8 @@ def _view_and_pad(tensor):
54105411 if non_blocking and device.type != "cuda":
54115412 # sync if needed
54125413 self._sync_all()
5413- torch.cat(items, out=storage)
5414+ if items:
5415+ torch.cat(items, out=storage)
54145416 for v, (k, oldv) in _zip_strict(
54155417 storage.split(flat_size), list(flat_dict.items())
54165418 ):
Original file line number Diff line number Diff line change @@ -11335,6 +11335,17 @@ def test_stack(self, non_tensor_data):
1133511335 LazyStackedTensorDict ,
1133611336 )
1133711337
11338+ def test_stack_consolidate (self ):
11339+ td = torch .stack (
11340+ [
11341+ TensorDict (a = "a string" , b = "b string" ),
11342+ TensorDict (a = "another string" , b = "bnother string" ),
11343+ ]
11344+ )
11345+ tdc = td .consolidate ()
11346+ assert (tdc == td ).all ()
11347+ assert tdc ["a" ] == ["a string" , "another string" ]
11348+
1133811349 def test_assign_non_tensor (self ):
1133911350 data = TensorDict ({}, [1 , 10 ])
1134011351
You can’t perform that action at this time.
0 commit comments