@@ -827,7 +827,7 @@ def set(
         if not self.initialized:
             if not isinstance(cursor, INT_CLASSES):
                 if is_tensor_collection(data):
-                    self._init(data[0])
+                    self._init(data, shape=data.shape[1:])
                 else:
                     self._init(tree_map(lambda x: x[0], data))
             else:
@@ -873,7 +873,7 @@ def set( # noqa: F811
                 )
         if not self.initialized:
             if not isinstance(cursor, INT_CLASSES):
-                self._init(data[0])
+                self._init(data, shape=data.shape[1:])
             else:
                 self._init(data)
         if not isinstance(cursor, (*INT_CLASSES, slice)):
@@ -993,6 +993,15 @@ class LazyTensorStorage(TensorStorage):
             Defaults to ``False``.
         consolidated (bool, optional): if ``True``, the storage will be consolidated after
             its first expansion. Defaults to ``False``.
+        empty_lazy (bool, optional): if ``True``, any lazy tensordict in the first tensordict
+            passed to the storage will be emptied of its content. This can be used to store
+            ragged data or content with exclusive keys (e.g., when some but not all environments
+            provide extra data to be stored in the buffer).
+            Setting ``empty_lazy`` to ``True`` requires :meth:`~.extend` to be called first
+            (a call to :meth:`~.add` will result in an exception).
+            Recall that data stored in lazy stacks is not contiguous in memory: indexing can be
+            slower than with contiguous data, and serialization is more hazardous. Use with caution!
+            Defaults to ``False``.
 
     Examples:
         >>> data = TensorDict({
@@ -1054,6 +1063,7 @@ def __init__(
         ndim: int = 1,
         compilable: bool = False,
         consolidated: bool = False,
+        empty_lazy: bool = False,
     ):
         super().__init__(
             storage=None,
@@ -1062,11 +1072,13 @@ def __init__(
             ndim=ndim,
             compilable=compilable,
         )
+        self.empty_lazy = empty_lazy
         self.consolidated = consolidated
 
     def _init(
         self,
         data: TensorDictBase | torch.Tensor | PyTree,  # noqa: F821
+        shape: torch.Size | None = None,
     ) -> None:
         if not self._compilable:
             # TODO: Investigate why this seems to have a performance impact with
@@ -1087,8 +1099,21 @@ def max_size_along_dim0(data_shape):
 
         if is_tensor_collection(data):
             out = data.to(self.device)
-            out: TensorDictBase = torch.empty_like(
-                out.expand(max_size_along_dim0(data.shape))
+            if self.empty_lazy:
+                if shape is None:
+                    # shape is None when called from add()
+                    raise RuntimeError(
+                        "Make sure you have called `extend` and not `add` first when setting `empty_lazy=True`."
+                    )
+                out: TensorDictBase = torch.empty_like(
+                    out.expand(max_size_along_dim0(data.shape))
+                )
+            elif shape is None:
+                shape = data.shape
+            else:
+                out = out[0]
+            out: TensorDictBase = out.new_empty(
+                max_size_along_dim0(shape), empty_lazy=self.empty_lazy
             )
             if self.consolidated:
                 out = out.consolidate()
@@ -1286,7 +1311,9 @@ def load_state_dict(self, state_dict):
         self.initialized = state_dict["initialized"]
         self._len = state_dict["_len"]
 
-    def _init(self, data: TensorDictBase | torch.Tensor) -> None:
+    def _init(
+        self, data: TensorDictBase | torch.Tensor, *, shape: torch.Size | None = None
+    ) -> None:
         torchrl_logger.debug("Creating a MemmapStorage...")
         if self.device == "auto":
             self.device = data.device
@@ -1304,8 +1331,14 @@ def max_size_along_dim0(data_shape):
             return (self.max_size, *data_shape)
 
         if is_tensor_collection(data):
+            if shape is None:
+                # Within add()
+                shape = data.shape
+            else:
+                # Get the first element - we don't care about empty_lazy in memmap storages
+                data = data[0]
             out = data.clone().to(self.device)
-            out = out.expand(max_size_along_dim0(data.shape))
+            out = out.expand(max_size_along_dim0(shape))
             out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
             if torchrl_logger.isEnabledFor(logging.DEBUG):
                 for key, tensor in sorted(
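Illustrative usage (not part of the commit): a minimal sketch of how the new `empty_lazy` flag is meant to be exercised, inferred from the docstring and the `extend` code path above. `ReplayBuffer`, `LazyTensorStorage`, `TensorDict` and `lazy_stack` are existing torchrl/tensordict APIs; the tensordict contents and the exact behavior with exclusive keys are assumptions made for the example.

    import torch
    from tensordict import TensorDict, lazy_stack
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    # Two transitions with exclusive keys: only the first one carries "extra".
    td0 = TensorDict({"obs": torch.zeros(3), "extra": torch.zeros(2)}, batch_size=[])
    td1 = TensorDict({"obs": torch.ones(3)}, batch_size=[])
    batch = lazy_stack([td0, td1])  # lazy stack, keys are not consolidated

    # With empty_lazy=True (added by this patch), the storage should keep the lazy
    # structure of the first batch instead of densifying it (assumption from the docstring).
    rb = ReplayBuffer(storage=LazyTensorStorage(100, empty_lazy=True), batch_size=2)
    rb.extend(batch)  # extend must initialize the storage when empty_lazy=True
    sample = rb.sample()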
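The guard added in `LazyTensorStorage._init` also means that initializing such a storage through `add` should fail, since no per-item shape is passed on that path; a sketch under the same assumptions:

    rb_bad = ReplayBuffer(storage=LazyTensorStorage(100, empty_lazy=True))
    try:
        rb_bad.add(td0)  # single-item add reaches _init without a shape
    except RuntimeError as err:
        print(err)  # "Make sure you have called `extend` and not `add` first ..."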