@@ -414,10 +414,12 @@ def _stack(
414414 out : T | None = None ,
415415 strict : bool = False ,
416416 contiguous : bool = False ,
417- maybe_dense_stack : bool = False ,
417+ maybe_dense_stack : bool | None = None ,
418418) -> T :
419419 if not len (list_of_tensordicts ):
420420 raise RuntimeError ("list_of_tensordicts cannot be empty" )
421+ if maybe_dense_stack is None :
422+ maybe_dense_stack = lazy_legacy ()
421423 is_tc = any (is_tensorclass (td ) for td in list_of_tensordicts )
422424 if all (is_non_tensor (td ) for td in list_of_tensordicts ):
423425 from tensordict .tensorclass import NonTensorData
@@ -457,7 +459,11 @@ def _stack(
457459 if not _lazy_legacy and not contiguous :
458460 if maybe_dense_stack :
459461 with set_lazy_legacy (True ):
460- return _stack (list_of_tensordicts , dim = dim )
462+ return _stack (
463+ list_of_tensordicts ,
464+ dim = dim ,
465+ maybe_dense_stack = maybe_dense_stack ,
466+ )
461467 else :
462468 raise RuntimeError (
463469 "The sets of keys in the tensordicts to stack are exclusive. "
@@ -490,7 +496,11 @@ def _stack(
490496 dim = dim - 1
491497 return LazyStackedTensorDict (
492498 * [
493- _stack (list (subtds ), dim = dim )
499+ _stack (
500+ list (subtds ),
501+ dim = dim ,
502+ maybe_dense_stack = maybe_dense_stack ,
503+ )
494504 for subtds in _zip_strict (
495505 * [td .tensordicts for td in list_of_tensordicts ]
496506 )
@@ -540,7 +550,11 @@ def _stack(
540550 # Nested tensors will require a lazy stack
541551 if maybe_dense_stack :
542552 with set_lazy_legacy (True ):
543- return _stack (list_of_tensordicts , dim = dim )
553+ return _stack (
554+ list_of_tensordicts ,
555+ dim = dim ,
556+ maybe_dense_stack = maybe_dense_stack ,
557+ )
544558 else :
545559 raise RuntimeError (
546560 f"The shapes of the tensors to stack is incompatible: { new_tensor_shape } vs { tensor_shape } for key { key } ."
0 commit comments