Skip to content

Commit e8ac44d

Browse files
authored
better setitem nontensorstack (#1367)
1 parent afcbcec commit e8ac44d

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

tensordict/_lazy.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2345,6 +2345,14 @@ def __setitem__(self, index: IndexType, value: T) -> T:
23452345
index = convert_ellipsis_to_idx(index, self.batch_size)
23462346
elif isinstance(index, (list, range)):
23472347
index = torch.as_tensor(index, device=self.device)
2348+
elif isinstance(index, (type(None), bool)) or (
2349+
isinstance(index, torch.Tensor)
2350+
and index.shape == ()
2351+
and index.dtype == torch.bool
2352+
and index.all()
2353+
):
2354+
self.unsqueeze(0).update(value)
2355+
return self
23482356

23492357
if is_tensor_collection(value) or isinstance(value, dict):
23502358
indexed_bs = _getitem_batch_size(self.batch_size, index)
@@ -2448,6 +2456,13 @@ def __getitem__(self, index: IndexType) -> Any:
24482456
return leaf.tolist()
24492457
return leaf.data
24502458
return leaf
2459+
if isinstance(index, (type(None), bool)) or (
2460+
isinstance(index, torch.Tensor)
2461+
and index.shape == ()
2462+
and index.dtype == torch.bool
2463+
and index.all()
2464+
):
2465+
return self.unsqueeze(0)
24512466
split_index = self._split_index(index)
24522467
converted_idx = split_index["index_dict"]
24532468
isinteger = split_index["isinteger"]

tensordict/_td.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,18 @@ def __setitem__(
852852
isinstance(index, tuple) and any(idx is Ellipsis for idx in index)
853853
):
854854
index = convert_ellipsis_to_idx(index, self.batch_size)
855+
# Convert index like (True,) or True to (0,) over unsqueezed self
856+
if isinstance(index, tuple) and len(index) == 1:
857+
index = index[0]
858+
if isinstance(index, (bool, type(None))) or (
859+
isinstance(index, torch.Tensor)
860+
and index.shape == ()
861+
and index.dtype == torch.bool
862+
and index.all()
863+
):
864+
with self.unsqueeze(0) as td_unsqueezed:
865+
td_unsqueezed[:] = value
866+
return
855867

856868
if isinstance(value, (TensorDictBase, dict)):
857869
indexed_bs = _getitem_batch_size(self.batch_size, index)
@@ -2531,6 +2543,17 @@ def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool):
25312543
inplace=False,
25322544
ignore_lock=True,
25332545
)
2546+
# TODO: ultimately, we want to get rid of the above logic
2547+
# dest_val = dest.maybe_to_stack()
2548+
# dest_val[idx] = value
2549+
# if dest_val is not dest:
2550+
# self._set_str(
2551+
# key,
2552+
# dest_val,
2553+
# validated=True,
2554+
# inplace=False,
2555+
# ignore_lock=True,
2556+
# )
25342557
return
25352558

25362559
if isinstance(idx, tuple) and len(idx) and isinstance(idx[0], tuple):

test/test_tensordict.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11851,6 +11851,26 @@ def test_new_empty_nontensorstack(self):
1185111851
assert isinstance(td.new_empty((4,), empty_lazy=True).get("a"), NonTensorStack)
1185211852
assert isinstance(td.new_empty((1,), empty_lazy=True).get("a"), NonTensorStack)
1185311853

11854+
def test_new_empty_setitem(self):
11855+
td = TensorDict(
11856+
a=TensorDict(
11857+
b=NonTensorStack("a", "b", "c").unsqueeze(-1), batch_size=(3,)
11858+
),
11859+
batch_size=(3,),
11860+
).to_lazystack()
11861+
tdz = td.new_zeros((4,), empty_lazy=True)
11862+
tdz[torch.tensor([True, True, False, True])] = td
11863+
assert tdz.get(("a", "b")).tolist() == [["a"], ["b"], ["a"], ["c"]]
11864+
11865+
def test_new_empty_setitem_2(self):
11866+
td = TensorDict(
11867+
a=TensorDict(b=NonTensorStack("a"), batch_size=(1,)), batch_size=(1,)
11868+
).to_lazystack()
11869+
tdz = td.new_zeros((4,), empty_lazy=True)
11870+
td["a", "b"] = "new"
11871+
tdz[torch.tensor([False, False, False, True])] = td
11872+
assert tdz["a", "b"][-1] == "new"
11873+
1185411874
def test_non_tensor_call(self):
1185511875
td0 = TensorDict({"a": 0, "b": 0})
1185611876
td1 = TensorDict({"a": 1, "b": 1})

0 commit comments

Comments (0)