
Commit 0bb94c0

nkkarpov and vmoens authored
[BugFix] Add check for split_size in TensorDict.split (#1370)
Co-authored-by: vmoens <vincentmoens@gmail.com>
1 parent c8bfda2 commit 0bb94c0

3 files changed: +57 -59 lines changed

tensordict/_td.py

Lines changed: 15 additions & 59 deletions
@@ -48,6 +48,8 @@
     _BatchedUninitializedParameter,
     _check_inbuild,
     _clone_value,
+    _create_segments_from_int,
+    _create_segments_from_list,
     _get_item,
     _get_leaf_tensordict,
     _get_shape_from_args,
@@ -1750,71 +1752,28 @@ def split(
         # we must use slices to keep the storage of the tensors
         WRONG_TYPE = "split(): argument 'split_size' must be int or list of ints"
         batch_size = self.batch_size
-        batch_sizes = []
         dim = _maybe_correct_neg_dim(dim, batch_size)
         max_size = batch_size[dim]
         if isinstance(split_size, int):
-            idx0 = 0
-            idx1 = min(max_size, split_size)
-            batch_sizes.append(
-                torch.Size(
-                    tuple(
-                        d if i != dim else idx1 - idx0 for i, d in enumerate(batch_size)
-                    )
-                )
-            )
-            while idx1 < max_size:
-                idx0 = idx1
-                idx1 = min(max_size, idx1 + split_size)
-                batch_sizes.append(
-                    torch.Size(
-                        tuple(
-                            d if i != dim else idx1 - idx0
-                            for i, d in enumerate(batch_size)
-                        )
-                    )
-                )
+            segments = _create_segments_from_int(split_size, max_size)
+            chunks = -(self.batch_size[dim] // -split_size)
+            splits = {k: v.chunk(chunks, dim) for k, v in self.items()}
         elif isinstance(split_size, (list, tuple)):
             if len(split_size) == 0:
                 raise RuntimeError("Insufficient number of elements in split_size.")
-            try:
-                idx0 = 0
-                idx1 = split_size[0]
-                batch_sizes.append(
-                    torch.Size(
-                        tuple(
-                            d if i != dim else idx1 - idx0
-                            for i, d in enumerate(batch_size)
-                        )
-                    )
-                )
-                for idx in split_size[1:]:
-                    idx0 = idx1
-                    idx1 = min(max_size, idx1 + idx)
-                    batch_sizes.append(
-                        torch.Size(
-                            tuple(
-                                d if i != dim else idx1 - idx0
-                                for i, d in enumerate(batch_size)
-                            )
-                        )
-                    )
-            except TypeError:
+            if not all(isinstance(x, int) for x in split_size):
                 raise TypeError(WRONG_TYPE)
-
-            if idx1 < batch_size[dim]:
-                raise RuntimeError(
-                    f"Split method expects split_size to sum exactly to {self.batch_size[dim]} (tensor's size at dimension {dim}), but got split_size={split_size}"
-                )
+            splits = {k: v.split(split_size, dim) for k, v in self.items()}
+            segments = _create_segments_from_list(split_size, max_size)
         else:
             raise TypeError(WRONG_TYPE)
         names = self._maybe_names()
-        # Use chunk instead of split to account for nested tensors if possible
-        if isinstance(split_size, int):
-            chunks = -(self.batch_size[dim] // -split_size)
-            splits = {k: v.chunk(chunks, dim) for k, v in self.items()}
-        else:
-            splits = {k: v.split(split_size, dim) for k, v in self.items()}
+        batch_sizes = [
+            torch.Size(
+                tuple(d if i != dim else end - start for i, d in enumerate(batch_size))
+            )
+            for start, end in segments
+        ]
         splits = [
             {k: v[ss] for k, v in splits.items()} for ss in range(len(batch_sizes))
         ]
@@ -2184,7 +2143,6 @@ def from_dict_instance(
         batch_dims=None,
         names=None,
     ):
-
         if batch_dims is not None and batch_size is not None:
             raise ValueError(
                 "Cannot pass both batch_size and batch_dims to `from_dict`."
@@ -2274,7 +2232,7 @@ def batch_dims(self) -> int:
     @batch_dims.setter
     def batch_dims(self, value: int) -> None:
         raise RuntimeError(
-            f"Setting batch dims on {type(self).__name__} instances is " f"not allowed."
+            f"Setting batch dims on {type(self).__name__} instances is not allowed."
         )

     def _has_names(self):
@@ -2763,7 +2721,6 @@ def _memmap_(
         share_non_tensor,
         existsok,
     ) -> T:
-
         if prefix is not None:
             prefix = Path(prefix)
             if not prefix.exists():
@@ -2806,7 +2763,6 @@ def _memmap_(
                 )
                 continue
             else:
-
                 if executor is None:
                     _populate_memmap(
                         dest=dest,
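For context, here is a minimal usage sketch of the `split` behavior this refactor preserves (not part of the commit; the tensor contents and batch size are illustrative assumptions):

```python
import torch
from tensordict import TensorDict

# A TensorDict with a single leaf tensor and batch size 10 along dim 0.
td = TensorDict({"a": torch.arange(10).unsqueeze(-1)}, batch_size=[10])

# Integer split_size: internally mapped to chunk with ceil(10 / 3) = 4 chunks,
# so the last piece is shorter (batch sizes 3, 3, 3, 1).
parts = td.split(3, dim=0)
print([p.batch_size for p in parts])

# List split_size: must sum exactly to the size of the split dimension.
parts = td.split([2, 3, 5], dim=0)
print([p.batch_size for p in parts])  # batch sizes 2, 3, 5
```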

tensordict/utils.py

Lines changed: 34 additions & 0 deletions
@@ -8,6 +8,7 @@
 import concurrent.futures
 import functools
 import inspect
+import itertools
 import logging

 import math
@@ -3050,3 +3051,36 @@ def _check_is_unflatten(new_shape, old_shape, return_flatten_dim=False):
         # j = len(new_shape) - j - 1
         return out, (i, j)
     return out
+
+
+def _create_segments_from_int(split_size, max_size):
+    if split_size <= 0:
+        raise RuntimeError(
+            f"split_size must be a positive integer, but got {split_size}."
+        )
+    splits = [
+        (start, min(start + split_size, max_size))
+        for start in range(0, max_size, split_size)
+    ]
+    return splits
+
+
+def _create_segments_from_list(
+    split_size: list[int] | tuple[int],
+    max_size: int,
+):
+    splits = [
+        (start, min(start + size, max_size))
+        for start, size in zip(
+            [0] + list(itertools.accumulate(split_size[:-1])),
+            split_size,
+        )
+    ]
+    total_split_size = sum(split_size)
+    if total_split_size != max_size:
+        raise RuntimeError(
+            f"Split method expects split_size to sum exactly to {max_size}, "
+            f"but got sum({split_size}) = {total_split_size}"
+        )
+
+    return splits
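A quick sketch of what the new helpers compute, namely half-open `(start, end)` index pairs along the split dimension (assumes a tensordict build that includes this commit, since both functions are private):

```python
from tensordict.utils import _create_segments_from_int, _create_segments_from_list

print(_create_segments_from_int(3, 10))
# [(0, 3), (3, 6), (6, 9), (9, 10)]

print(_create_segments_from_list([2, 3, 5], 10))
# [(0, 2), (2, 5), (5, 10)]

# Invalid inputs now raise RuntimeError with a clear message:
# _create_segments_from_int(0, 10)        -> split_size must be a positive integer
# _create_segments_from_list([2, 3], 10)  -> split_size must sum exactly to 10
```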

test/test_tensordict.py

Lines changed: 8 additions & 0 deletions
@@ -2744,6 +2744,14 @@ def test_split_with_invalid_arguments(self):
             td.split(1, 2)
         with pytest.raises(IndexError, match="Incompatible dim"):
             td.split(1, -3)
+        with pytest.raises(
+            RuntimeError, match="split_size must be a positive integer, but got 0."
+        ):
+            td.split(0, -1)
+        with pytest.raises(
+            RuntimeError, match="split_size must be a positive integer, but got -1."
+        ):
+            td.split(-1, -1)

     def test_split_with_negative_dim(self):
         td = TensorDict(
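The added checks can be reproduced interactively with a sketch like the following (the batch size is an illustrative assumption, not the one used in the test suite):

```python
import pytest
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(4, 2)}, batch_size=[4, 2])

# A zero or negative split_size is rejected before any splitting happens.
with pytest.raises(RuntimeError, match="split_size must be a positive integer"):
    td.split(0, -1)
with pytest.raises(RuntimeError, match="split_size must be a positive integer"):
    td.split(-1, -1)
```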
