Skip to content

Commit 4bc320f

Browse files
author
Tim Joseph
committed
fix(container): improve 'too many indices' error handling
Reordered index validation checks in `_normalize_idx` to consistently raise `IndexError` when too many dimensions are provided. Previously, an index without Ellipses may have skipped this check. Updated the error message from "array is X-dimensional" to "container is X-dimensional" for better clarity and consistency. Adjusted existing tests to expect `IndexError` instead of `RuntimeError` and added a new test case to explicitly verify this error scenario.
1 parent c21b636 commit 4bc320f

File tree

3 files changed

+24
-16
lines changed

3 files changed

+24
-16
lines changed

src/tensorcontainer/tensor_container.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -373,32 +373,31 @@ def transform_ellipsis_index(self, shape: torch.Size, idx: tuple) -> tuple:
373373
Transforms an indexing tuple with an ellipsis into an equivalent one without it.
374374
...
375375
"""
376-
if Ellipsis not in idx:
377-
return idx
378-
379-
ellipsis_count = 0
380-
for item in idx:
381-
if item is Ellipsis:
382-
ellipsis_count += 1
383-
if ellipsis_count > 1:
384-
raise IndexError("an index can only have a single ellipsis ('...')")
385-
386-
ellipsis_pos = idx.index(Ellipsis)
387-
388376
# Count how many items in the index "consume" an axis from the original shape.
389377
# `None` adds a new axis, so it's not counted.
390378
num_consuming_indices = sum(
391379
self.get_number_of_consuming_dims(item) for item in idx
392380
)
393381

394382
rank = len(shape)
395-
396383
if num_consuming_indices > rank:
397384
raise IndexError(
398-
f"too many indices for array: array is {rank}-dimensional, "
385+
f"too many indices for container: container is {rank}-dimensional, "
399386
f"but {num_consuming_indices} were indexed"
400387
)
401388

389+
if Ellipsis not in idx:
390+
return idx
391+
392+
ellipsis_count = 0
393+
for item in idx:
394+
if item is Ellipsis:
395+
ellipsis_count += 1
396+
if ellipsis_count > 1:
397+
raise IndexError("an index can only have a single ellipsis ('...')")
398+
399+
ellipsis_pos = idx.index(Ellipsis)
400+
402401
# Calculate slices needed based on the consuming indices
403402
num_slices_to_add = rank - num_consuming_indices
404403

tests/tensor_dataclass/test_setitem.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def test_setitem_broadcast_with_tdc(self, test_name, idx, source_shape):
195195
"too_many_indices",
196196
(slice(None), slice(None), 0),
197197
(20, 5), # Assign a valid TDC
198-
RuntimeError, # Changed from IndexError to RuntimeError
199-
r"Issue with key \.(features|labels) and index \(slice\(None, None, None\), slice\(None, None, None\), 0\) for value of shape torch\.Size\(\[20, 5, 10\]\) and type <class 'torch\.Tensor'> and assignment of shape \(20, 5\)",
198+
IndexError,
199+
r"too many indices for container: container is \d+-dimensional, but \d+ were indexed",
200200
),
201201
]
202202

tests/tensor_dict/test_getitem.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,12 @@ def to_device(obj):
124124
assert normalize_device(nested["b"].device) == normalize_device(
125125
torch.device(device)
126126
)
127+
128+
129+
def test_invalid_getitem_raises_error():
130+
td = TensorDict({"a": torch.randn(2, 3, 4)}, shape=[2, 3])
131+
with pytest.raises(
132+
IndexError,
133+
match="too many indices for container: container is 2-dimensional, but 3 were indexed",
134+
):
135+
td[:, :, 0]

0 commit comments

Comments
 (0)