Skip to content

Commit 1e32195

Browse files
author
Vincent Moens
committed
[BugFix] Propagate maybe_dense_stack in _stack
ghstack-source-id: a1cb1de Pull Request resolved: #1036
1 parent d147be4 commit 1e32195

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

tensordict/_torch_func.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -414,10 +414,12 @@ def _stack(
414414
out: T | None = None,
415415
strict: bool = False,
416416
contiguous: bool = False,
417-
maybe_dense_stack: bool = False,
417+
maybe_dense_stack: bool | None = None,
418418
) -> T:
419419
if not len(list_of_tensordicts):
420420
raise RuntimeError("list_of_tensordicts cannot be empty")
421+
if maybe_dense_stack is None:
422+
maybe_dense_stack = lazy_legacy()
421423
is_tc = any(is_tensorclass(td) for td in list_of_tensordicts)
422424
if all(is_non_tensor(td) for td in list_of_tensordicts):
423425
from tensordict.tensorclass import NonTensorData
@@ -457,7 +459,11 @@ def _stack(
457459
if not _lazy_legacy and not contiguous:
458460
if maybe_dense_stack:
459461
with set_lazy_legacy(True):
460-
return _stack(list_of_tensordicts, dim=dim)
462+
return _stack(
463+
list_of_tensordicts,
464+
dim=dim,
465+
maybe_dense_stack=maybe_dense_stack,
466+
)
461467
else:
462468
raise RuntimeError(
463469
"The sets of keys in the tensordicts to stack are exclusive. "
@@ -490,7 +496,11 @@ def _stack(
490496
dim = dim - 1
491497
return LazyStackedTensorDict(
492498
*[
493-
_stack(list(subtds), dim=dim)
499+
_stack(
500+
list(subtds),
501+
dim=dim,
502+
maybe_dense_stack=maybe_dense_stack,
503+
)
494504
for subtds in _zip_strict(
495505
*[td.tensordicts for td in list_of_tensordicts]
496506
)
@@ -540,7 +550,11 @@ def _stack(
540550
# Nested tensors will require a lazy stack
541551
if maybe_dense_stack:
542552
with set_lazy_legacy(True):
543-
return _stack(list_of_tensordicts, dim=dim)
553+
return _stack(
554+
list_of_tensordicts,
555+
dim=dim,
556+
maybe_dense_stack=maybe_dense_stack,
557+
)
544558
else:
545559
raise RuntimeError(
546560
f"The shapes of the tensors to stack is incompatible: {new_tensor_shape} vs {tensor_shape} for key {key}."

0 commit comments

Comments
 (0)