Skip to content

Commit c94a036

Browse files
committed
[BugFix] Uneven splits (#1376)
1 parent 96a3e68 commit c94a036

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

tensordict/_td.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,8 +1751,8 @@ def split(
17511751
max_size = batch_size[dim]
17521752
if isinstance(split_size, int):
17531753
segments = _create_segments_from_int(split_size, max_size)
1754-
chunks = -(self.batch_size[dim] // -split_size)
1755-
splits = {k: v.chunk(chunks, dim) for k, v in self.items()}
1754+
splits = [end - start for start, end in segments]
1755+
splits = {k: v.split(splits, dim) for k, v in self.items()}
17561756
elif isinstance(split_size, (list, tuple)):
17571757
if len(split_size) == 0:
17581758
raise RuntimeError("Insufficient number of elements in split_size.")

test/test_tensordict.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2665,6 +2665,38 @@ def test_split_keys(self):
26652665
assert td is td3
26662666
assert "d" not in td
26672667

2668+
@pytest.mark.parametrize("sign", ["plus", "minus"])
2669+
def test_split_uneven(self, sign):
2670+
a = torch.arange(6).unsqueeze(-1).expand(6, 3)
2671+
b = torch.arange(18).view(6, 3)
2672+
c = torch.arange(36).view(6, 3, 2)
2673+
td = TensorDict({"a": a, "b": b, "c": c}, [6, 3])
2674+
2675+
if sign == "plus":
2676+
tds = td.split(5, 0)
2677+
else:
2678+
tds = td.split(5, -2)
2679+
assert tds[0].shape == torch.Size([5, 3])
2680+
assert tds[1].shape == torch.Size([1, 3])
2681+
assert tds[0]["a"].shape == torch.Size([5, 3])
2682+
assert tds[1]["a"].shape == torch.Size([1, 3])
2683+
assert tds[0]["b"].shape == torch.Size([5, 3])
2684+
assert tds[1]["b"].shape == torch.Size([1, 3])
2685+
assert tds[0]["c"].shape == torch.Size([5, 3, 2])
2686+
assert tds[1]["c"].shape == torch.Size([1, 3, 2])
2687+
if sign == "plus":
2688+
tds = td.split(2, 1)
2689+
else:
2690+
tds = td.split(2, -1)
2691+
assert tds[0].shape == torch.Size([6, 2])
2692+
assert tds[1].shape == torch.Size([6, 1])
2693+
assert tds[0]["a"].shape == torch.Size([6, 2])
2694+
assert tds[1]["a"].shape == torch.Size([6, 1])
2695+
assert tds[0]["b"].shape == torch.Size([6, 2])
2696+
assert tds[1]["b"].shape == torch.Size([6, 1])
2697+
assert tds[0]["c"].shape == torch.Size([6, 2, 2])
2698+
assert tds[1]["c"].shape == torch.Size([6, 1, 2])
2699+
26682700
def test_setitem_nested(self):
26692701
tensor = torch.randn(4, 5, 6, 7)
26702702
tensor2 = torch.ones(4, 5, 6, 7)

0 commit comments

Comments
 (0)