Skip to content

Commit 6d8119c

Browse files
author
Vincent Moens
committed
[BugFix] Better list assignment in tensorclasses
ghstack-source-id: 001b0c0 Pull Request resolved: #1284
1 parent a9cc632 commit 6d8119c

File tree

5 files changed

+181
-77
lines changed

5 files changed

+181
-77
lines changed

.github/unittest/linux/scripts/run_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
1919
export MKL_THREADING_LAYER=GNU
2020
export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1
2121
export TD_GET_DEFAULTS_TO_NONE=1
22+
export LIST_TO_STACK=1
2223

2324
coverage run -m pytest test/smoke_test.py -v --durations 20
2425
coverage run -m pytest --runslow --instafail -v --durations 20 --timeout 120

tensordict/_lazy.py

Lines changed: 78 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -278,37 +278,14 @@ def __init__(
278278
raise RuntimeError(
279279
f"Couldn't infer stack dim from negative value, got stack_dim={stack_dim}"
280280
)
281-
_batch_size = td0.batch_size
282-
if stack_dim > len(_batch_size):
281+
self.stack_dim = stack_dim
282+
self._reset_batch_size(td0, tensordicts, device, num_tds, strict_shape)
283+
if stack_dim > len(self.batch_size):
283284
raise RuntimeError(
284-
f"Stack dim {stack_dim} is too big for batch size {_batch_size}."
285+
f"Stack dim {stack_dim} is too big for batch size {self.batch_size}."
285286
)
286287

287-
for td in tensordicts[1:]:
288-
if not is_tensor_collection(td):
289-
raise TypeError(
290-
"Expected all inputs to be TensorDictBase instances but got "
291-
f"{type(td)} instead."
292-
)
293-
_bs = td.batch_size
294-
_device = td.device
295-
if device != _device:
296-
raise RuntimeError(f"devices differ, got {device} and {_device}")
297-
if _bs != _batch_size:
298-
if strict_shape or len(_bs) != len(_batch_size):
299-
raise RuntimeError(
300-
f"batch sizes in tensordicts differs, LazyStackedTensorDict "
301-
f"cannot be created. Got td[0].batch_size={_batch_size} "
302-
f"and td[i].batch_size={_bs}. If the length match and you wish "
303-
f"to stack these tensordicts, set strict_shape to False."
304-
)
305-
else:
306-
_batch_size = torch.Size(
307-
[s if _bs[i] == s else -1 for i, s in enumerate(_batch_size)]
308-
)
309288
self.tensordicts: list[TensorDictBase] = list(tensordicts)
310-
self.stack_dim = stack_dim
311-
self._batch_size = self._compute_batch_size(_batch_size, stack_dim, num_tds)
312289
self.hook_out = hook_out
313290
self.hook_in = hook_in
314291
if batch_size is not None and batch_size != self.batch_size and num_tds != 0:
@@ -578,6 +555,41 @@ def is_memmap(self) -> bool:
578555
)
579556
return are_memmap[0]
580557

558+
def _reset_batch_size(
559+
self,
560+
td0: TensorDictBase,
561+
tensordicts: list[TensorDictBase],
562+
device: torch.device,
563+
num_tds: int,
564+
strict_shape: bool,
565+
):
566+
_batch_size = td0.batch_size
567+
stack_dim = self.stack_dim
568+
569+
for td in tensordicts[1:]:
570+
if not is_tensor_collection(td):
571+
raise TypeError(
572+
"Expected all inputs to be TensorDictBase instances but got "
573+
f"{type(td)} instead."
574+
)
575+
_bs = td.batch_size
576+
_device = td.device
577+
if device != _device:
578+
raise RuntimeError(f"devices differ, got {device} and {_device}")
579+
if _bs != _batch_size:
580+
if strict_shape or len(_bs) != len(_batch_size):
581+
raise RuntimeError(
582+
f"batch sizes in tensordicts differs, LazyStackedTensorDict "
583+
f"cannot be created. Got td[0].batch_size={_batch_size} "
584+
f"and td[i].batch_size={_bs}. If the length match and you wish "
585+
f"to stack these tensordicts, set strict_shape to False."
586+
)
587+
else:
588+
_batch_size = torch.Size(
589+
[s if _bs[i] == s else -1 for i, s in enumerate(_batch_size)]
590+
)
591+
self._batch_size = self._compute_batch_size(_batch_size, stack_dim, num_tds)
592+
581593
@staticmethod
582594
def _compute_batch_size(
583595
batch_size: torch.Size, stack_dim: int, num_tds: int
@@ -606,7 +618,9 @@ def _set_str(
606618
) from e
607619
if not validated:
608620
value = self._validate_value(
609-
value, non_blocking=non_blocking, check_shape=not list_to_stack()
621+
value,
622+
non_blocking=non_blocking,
623+
check_shape=not (isinstance(value, list) and list_to_stack()),
610624
)
611625
validated = True
612626
if self._is_vmapped:
@@ -3147,6 +3161,42 @@ def append(self, tensordict: T) -> None:
31473161
"""
31483162
self.insert(len(self.tensordicts), tensordict)
31493163

3164+
@lock_blocked
3165+
def extend(self, tensordict: list[T] | T) -> None:
3166+
"""Extends the lazy stack with new tensordicts."""
3167+
if _is_tensor_collection(type(tensordict)):
3168+
tensordict = list(tensordict.unbind(self.stack_dim))
3169+
if any(not isinstance(tensordict, TensorDictBase) for tensordict in tensordict):
3170+
raise TypeError(
3171+
"Expected new value to be TensorDictBase instance but got "
3172+
f"{[type(tensordict) for tensordict in tensordict]} instead."
3173+
)
3174+
if self.tensordicts:
3175+
batch_size = self.tensordicts[0].batch_size
3176+
device = self.tensordicts[0].device
3177+
3178+
for _td in tensordict:
3179+
_batch_size = _td.batch_size
3180+
_device = _td.device
3181+
3182+
if device != _device:
3183+
raise ValueError(
3184+
f"Devices differ: stack has device={device}, new value has "
3185+
f"device={_device}."
3186+
)
3187+
if _batch_size != batch_size:
3188+
raise ValueError(
3189+
f"Batch sizes in tensordicts differs: stack has "
3190+
f"batch_size={batch_size}, new_value has batch_size={_batch_size}."
3191+
)
3192+
else:
3193+
batch_size = tensordict.batch_size
3194+
3195+
self.tensordicts.extend(tensordict)
3196+
3197+
N = len(self.tensordicts)
3198+
self._batch_size = self._compute_batch_size(batch_size, self.stack_dim, N)
3199+
31503200
@property
31513201
def is_locked(self) -> bool:
31523202
if self._is_locked is not None:

tensordict/tensorclass.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
IndexType,
7373
is_tensorclass,
7474
KeyDependentDefaultDict,
75+
list_to_stack,
7576
set_capture_non_tensor_stack,
7677
)
7778
from torch import multiprocessing as mp, Tensor
@@ -363,6 +364,7 @@ def __subclasscheck__(self, subclass):
363364
"expand_as",
364365
"expm1",
365366
"expm1_",
367+
"extend",
366368
"fill_",
367369
"filter_empty_",
368370
"filter_non_tensor_data",
@@ -1025,6 +1027,9 @@ def __torch_function__(
10251027
_wrap_td_method(method_name, copy_non_tensor=True),
10261028
)
10271029

1030+
# if not hasattr(cls, "batch_size") and "batch_size" not in expected_keys:
1031+
# cls.batch_size = property(_batch_size, _batch_size_setter)
1032+
10281033
cls.__enter__ = __enter__
10291034
cls.__exit__ = __exit__
10301035

@@ -1080,6 +1085,12 @@ def __torch_function__(
10801085
return cls
10811086

10821087

1088+
# def _batch_size(self):
1089+
# return self.__dict__["_tensordict"]._batch_size
1090+
# def _batch_size_setter(self, value):
1091+
# self.__dict__["_tensordict"].batch_size = value
1092+
1093+
10831094
def _arg_to_tensordict(arg):
10841095
# if arg is a tensorclass or sequence of tensorclasses, extract the underlying
10851096
# tensordicts and return those instead
@@ -2347,6 +2358,9 @@ def _is_castable(datatype):
23472358
)
23482359
):
23492360
return set_tensor()
2361+
elif issubclass(value_type, list) and list_to_stack():
2362+
# set() will take care of casting to non tensor
2363+
non_tensor = False
23502364
else:
23512365
non_tensor = True
23522366

test/test_tensorclass.py

Lines changed: 83 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
LazyStackedTensorDict,
3636
MemoryMappedTensor,
3737
set_capture_non_tensor_stack,
38+
set_list_to_stack,
3839
tensorclass,
3940
TensorClass,
4041
TensorDict,
@@ -1032,19 +1033,28 @@ class MyDataParent:
10321033
assert data.y.v == "test_nested"
10331034
assert data.y.batch_size == torch.Size(batch_size)
10341035

1035-
def test_indexing(self):
1036-
@tensorclass
1037-
class MyDataNested:
1038-
X: torch.Tensor
1039-
z: list
1040-
y: "MyDataNested" = None
1041-
1042-
X = torch.ones(3, 4, 5)
1043-
z = ["a", "b", "c"]
1044-
batch_size = [3, 4]
1045-
data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
1046-
data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
1036+
@pytest.mark.parametrize("list_to_stack", [True, False])
1037+
def test_indexing(self, list_to_stack):
1038+
with set_list_to_stack(list_to_stack):
10471039

1040+
@tensorclass
1041+
class MyDataNested:
1042+
X: torch.Tensor
1043+
z: list
1044+
y: "MyDataNested" = None
1045+
1046+
X = torch.ones(3, 4, 5)
1047+
z = ["a", "b", "c"]
1048+
batch_size = [3, 4]
1049+
with (
1050+
pytest.raises(RuntimeError, match="batch dimension mismatch")
1051+
if list_to_stack
1052+
else contextlib.nullcontext()
1053+
):
1054+
data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
1055+
data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
1056+
if list_to_stack:
1057+
return
10481058
assert data[:2].batch_size == torch.Size([2, 4])
10491059
assert data[:2].X.shape == torch.Size([2, 4, 5])
10501060
assert (data[:2].X == X[:2]).all()
@@ -1462,6 +1472,21 @@ class Data:
14621472
assert (data_select == 1).all()
14631473
assert "a" in data_select._tensordict
14641474

1475+
@set_list_to_stack(True)
1476+
def test_set_list_in_constructor(self):
1477+
obj = MyTensorClass(
1478+
a=["a string", "another string"],
1479+
b=[torch.randn(3), torch.zeros(3)],
1480+
c="smth completly different",
1481+
batch_size=2,
1482+
)
1483+
assert obj.shape == (2,)
1484+
assert obj[0].a == "a string"
1485+
assert obj[1].a == "another string"
1486+
assert (obj[0].b != 0).all()
1487+
assert (obj[1].b == 0).all()
1488+
assert obj.c == obj[0].c
1489+
14651490
def test_set_dict(self):
14661491
@tensorclass(autocast=True)
14671492
class MyClass:
@@ -1540,7 +1565,8 @@ class MyDataParent:
15401565
# ensure optional fields are writable
15411566
data.k = torch.zeros(3, 4, 5)
15421567

1543-
def test_setitem(self):
1568+
@pytest.mark.parametrize("list_to_stack", [True, False])
1569+
def test_setitem(self, list_to_stack):
15441570
data = MyData(
15451571
X=torch.ones(3, 4, 5),
15461572
y=torch.zeros(3, 4, 5),
@@ -1599,26 +1625,34 @@ class MyDataNested:
15991625
X = torch.randn(3, 4, 5)
16001626
z = ["a", "b", "c"]
16011627
batch_size = [3, 4]
1602-
data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
1603-
data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
1604-
X2 = torch.ones(3, 4, 5)
1605-
data_nest2 = MyDataNested(X=X2, z=z, batch_size=batch_size)
1606-
data2 = MyDataNested(X=X2, y=data_nest2, z=z, batch_size=batch_size)
1607-
data[:2] = data2[:2].clone()
1608-
assert (data[:2].X == data2[:2].X).all()
1609-
assert (data[:2].y.X == data2[:2].y.X).all()
1610-
assert data[:2].z == z
1611-
1612-
# Negative Scenario
1613-
data3 = MyDataNested(X=X2, y=data_nest2, z=["e", "f"], batch_size=batch_size)
1614-
data[:2] = data3[:2]
1615-
assert data[:2].z == data3[:2]._get_str("z", None).tolist()
1628+
with set_list_to_stack(list_to_stack), (
1629+
pytest.raises(RuntimeError, match="batch dimension mismatch")
1630+
if list_to_stack
1631+
else contextlib.nullcontext()
1632+
):
1633+
data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
1634+
data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
1635+
X2 = torch.ones(3, 4, 5)
1636+
data_nest2 = MyDataNested(X=X2, z=z, batch_size=batch_size)
1637+
data2 = MyDataNested(X=X2, y=data_nest2, z=z, batch_size=batch_size)
1638+
data[:2] = data2[:2].clone()
1639+
assert (data[:2].X == data2[:2].X).all()
1640+
assert (data[:2].y.X == data2[:2].y.X).all()
1641+
assert data[:2].z == z
1642+
1643+
# Negative Scenario
1644+
data3 = MyDataNested(
1645+
X=X2, y=data_nest2, z=["e", "f"], batch_size=batch_size
1646+
)
1647+
data[:2] = data3[:2]
1648+
assert data[:2].z == data3[:2]._get_str("z", None).tolist()
16161649

16171650
@pytest.mark.parametrize(
16181651
"broadcast_type",
16191652
["scalar", "tensor", "tensordict", "maptensor"],
16201653
)
1621-
def test_setitem_broadcast(self, broadcast_type):
1654+
@pytest.mark.parametrize("list_to_stack", [True, False])
1655+
def test_setitem_broadcast(self, broadcast_type, list_to_stack):
16221656
@tensorclass
16231657
class MyDataNested:
16241658
X: torch.Tensor
@@ -1628,22 +1662,27 @@ class MyDataNested:
16281662
X = torch.ones(3, 4, 5)
16291663
z = ["a", "b", "c"]
16301664
batch_size = [3, 4]
1631-
data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
1632-
data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
1633-
1634-
if broadcast_type == "scalar":
1635-
val = 0
1636-
elif broadcast_type == "tensor":
1637-
val = torch.zeros(4, 5)
1638-
elif broadcast_type == "tensordict":
1639-
val = TensorDict({"X": torch.zeros(2, 4, 5)}, batch_size=[2, 4])
1640-
elif broadcast_type == "maptensor":
1641-
val = MemoryMappedTensor.from_tensor(torch.zeros(4, 5))
1642-
1643-
data[:2] = val
1644-
assert (data[:2] == 0).all()
1645-
assert (data.X[:2] == 0).all()
1646-
assert (data.y.X[:2] == 0).all()
1665+
with set_list_to_stack(list_to_stack), (
1666+
pytest.raises(RuntimeError, match="batch dimension mismatch")
1667+
if list_to_stack
1668+
else contextlib.nullcontext()
1669+
):
1670+
data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
1671+
data = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
1672+
1673+
if broadcast_type == "scalar":
1674+
val = 0
1675+
elif broadcast_type == "tensor":
1676+
val = torch.zeros(4, 5)
1677+
elif broadcast_type == "tensordict":
1678+
val = TensorDict({"X": torch.zeros(2, 4, 5)}, batch_size=[2, 4])
1679+
elif broadcast_type == "maptensor":
1680+
val = MemoryMappedTensor.from_tensor(torch.zeros(4, 5))
1681+
1682+
data[:2] = val
1683+
assert (data[:2] == 0).all()
1684+
assert (data.X[:2] == 0).all()
1685+
assert (data.y.X[:2] == 0).all()
16471686

16481687
def test_setitem_memmap(self):
16491688
# regression test PR #203

test/test_tensordict.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11864,16 +11864,15 @@ def test_shared_memmap_single(self, pair, strategy, update, tmpdir):
1186411864

1186511865
@staticmethod
1186611866
def _run_worker(td, val1, update):
11867+
set_list_to_stack(True).set()
1186711868
# Update in place
1186811869
if update == "setitem":
11869-
td["val"] = val1
11870+
td["val"] = NonTensorData(val1)
1187011871
elif update == "update_":
11871-
td.get("val").update_(
11872-
NonTensorData(data=val1, batch_size=[]), non_blocking=False
11873-
)
11872+
td.get("val").update_(NonTensorData(data=val1), non_blocking=False)
1187411873
elif update == "update-inplace":
1187511874
td.get("val").update(
11876-
NonTensorData(data=val1, batch_size=[]),
11875+
NonTensorData(data=val1),
1187711876
inplace=True,
1187811877
non_blocking=False,
1187911878
)
@@ -11884,6 +11883,7 @@ def _run_worker(td, val1, update):
1188411883
assert td["val"] == val1
1188511884

1188611885
@pytest.mark.slow
11886+
@set_list_to_stack(True)
1188711887
@pytest.mark.parametrize("pair", PAIRS)
1188811888
@pytest.mark.parametrize("strategy", ["shared", "memmap"])
1188911889
@pytest.mark.parametrize("update", ["update_", "update-inplace"])

0 commit comments

Comments
 (0)