
Commit c8bfda2

[BugFix] Fix update with update_batch_size when source is TD and dest is LTD (#1371)
1 parent e8ac44d commit c8bfda2

File tree

2 files changed (+36, -6 lines)


tensordict/_lazy.py

Lines changed: 23 additions & 6 deletions
@@ -3085,13 +3085,30 @@ def update(
         if input_dict_or_td.ndim <= self_upd.stack_dim or input_dict_or_td.batch_size[
             self_upd.stack_dim
         ] != len(self_upd.tensordicts):
+            # We receive a tensordict with a different batch-size than self.
+            # If update_batch_size is True, we can just convert the input tensordict to a lazy stack and update.
+            # This will change self, most likely removing a bunch of sub-tensordicts but we're good because the
+            # user is interested in modifying the batch-size.
+            # If update_batch_size is False, we need to try to change the batch-size of the input tensordict.
+            # That can only be done in restricted cases, so we raise an error if the batch-size of self (which must
+            # remain unchanged) is incompatible with the content of the input tensordict.
+            if update_batch_size:
+                return self.update(
+                    input_dict_or_td.to_lazystack(self.stack_dim),
+                    clone=clone,
+                    keys_to_update=keys_to_update,
+                    non_blocking=non_blocking,
+                    is_leaf=is_leaf,
+                    update_batch_size=update_batch_size,
+                    **kwargs,
+                )
+            # if the batch-size does not permit unbinding, let's first try to reset the batch-size.
+            input_dict_or_td = input_dict_or_td.copy()
+            batch_size = self_upd.batch_size
+            if self_upd.hook_out is not None:
+                batch_size = list(batch_size)
+                batch_size.insert(self_upd.stack_dim, len(self_upd.tensordicts))
             try:
-                # if the batch-size does not permit unbinding, let's first try to reset the batch-size.
-                input_dict_or_td = input_dict_or_td.copy()
-                batch_size = self_upd.batch_size
-                if self_upd.hook_out is not None:
-                    batch_size = list(batch_size)
-                    batch_size.insert(self_upd.stack_dim, len(self_upd.tensordicts))
                 input_dict_or_td.batch_size = batch_size
             except RuntimeError as err:
                 raise ValueError(
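
For context, a minimal sketch of the call path this hunk fixes: the destination is a lazy stack (LTD) and the source a plain TensorDict (TD), as in the commit title. The shapes are illustrative; to_lazystack and update(..., update_batch_size=True) are the calls exercised by the hunk above and the test below.

    import torch
    from tensordict import TensorDict

    # Destination: a lazy stack along dim 0 with batch size (3, 4).
    dest = TensorDict(
        a=torch.zeros(3, 4), b=torch.randn(3, 4), batch_size=[3, 4]
    ).to_lazystack(0)

    # Source: a plain (non-lazy) TensorDict with a different leading batch size.
    source = TensorDict(a=torch.ones(2, 4), b=torch.randn(2, 4), batch_size=[2, 4])

    # With update_batch_size=True, the new branch converts the source via
    # to_lazystack(self.stack_dim) and recurses, letting dest's batch size change.
    dest.update(source, update_batch_size=True)
    assert dest.batch_size == torch.Size([2, 4])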

test/test_tensordict.py

Lines changed: 13 additions & 0 deletions
@@ -10246,6 +10246,19 @@ def test_update_with_lazy(self):
         assert (td_void.get(("parent", "a", "b"))[1].get("d") == 0).all()  # unaffected
         assert (td_void.get(("parent", "a", "b")).get("e") == 0).all()  # unaffected
 
+    @pytest.mark.parametrize("source_is_lazy", [True, False])
+    def test_update_batch_size(self, source_is_lazy):
+        td = TensorDict(
+            a=torch.zeros(3, 4), b=torch.randn(3, 4), batch_size=[3, 4]
+        ).to_lazystack(0)
+        td2 = TensorDict(a=torch.ones(2, 4), b=torch.randn(2, 4), batch_size=[2, 4])
+        if source_is_lazy:
+            td2 = td2.to_lazystack(0)
+        td.update(td2, update_batch_size=True)
+        assert td.batch_size == td2.batch_size
+        assert td.batch_size == (2, 4)
+        assert td.batch_size == td2.batch_size
+
 
 @pytest.mark.skipif(
     not _has_torchsnapshot, reason=f"torchsnapshot not found: err={TORCHSNAPSHOT_ERR}"
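
For contrast, with the default update_batch_size=False the same branch instead tries to reset the source's batch-size to match self. A hedged sketch of the expected failure mode, based on the "except RuntimeError ... raise ValueError(" path at the end of the hunk (this snippet is not part of the commit's tests):

    import pytest
    import torch
    from tensordict import TensorDict

    dest = TensorDict(a=torch.zeros(3, 4), batch_size=[3, 4]).to_lazystack(0)
    source = TensorDict(a=torch.ones(2, 4), batch_size=[2, 4])

    # The source's (2, 4) tensors cannot take dest's (3, 4) batch size, so the
    # RuntimeError from the batch-size reset should be re-raised as a ValueError.
    with pytest.raises(ValueError):
        dest.update(source)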
