
Commit aa46357

Merge pull request #25 from mctigger/fix-setitem
Fix __setitem__ in TensorContainer
2 parents b377899 + 131a09b

File tree: 4 files changed (+333, -64 lines)

src/tensorcontainer/tensor_container.py

Lines changed: 101 additions & 20 deletions
@@ -22,7 +22,7 @@
 from torch.utils._pytree import Context, KeyEntry, PyTree
 from typing_extensions import Self, TypeAlias
 
-from tensorcontainer.types import DeviceLike, ShapeLike
+from tensorcontainer.types import DeviceLike, IndexType, ShapeLike
 from tensorcontainer.utils import (
     ContextWithAnalysis,
     diagnose_pytree_structure_mismatch,
@@ -429,6 +429,43 @@ def copy(self) -> Self:
         return self._tree_map(lambda x: x, self)
 
     def get_number_of_consuming_dims(self, item) -> int:
+        """
+        Returns the number of container dimensions consumed by an indexing item.
+
+        This method is crucial for the ellipsis expansion calculation. "Consuming"
+        means the index item selects from existing container dimensions, reducing
+        the container's rank. "Non-consuming" items either don't affect existing
+        dimensions (Ellipsis) or add new dimensions (None).
+
+        Args:
+            item: An indexing element from an indexing tuple.
+
+        Returns:
+            Number of container dimensions this item consumes:
+            - 0 for non-consuming items (Ellipsis, None)
+            - item.ndim for boolean tensors (advanced indexing)
+            - 1 for standard consuming items (int, slice, non-bool tensor)
+
+        Examples:
+            >>> container.get_number_of_consuming_dims(0)  # int
+            1
+            >>> container.get_number_of_consuming_dims(slice(0, 2))  # slice
+            1
+            >>> container.get_number_of_consuming_dims(...)  # Ellipsis
+            0
+            >>> container.get_number_of_consuming_dims(None)  # None (newaxis)
+            0
+            >>> bool_mask = torch.tensor([[True, False], [False, True]])
+            >>> container.get_number_of_consuming_dims(bool_mask)  # 2D bool tensor
+            2
+            >>> indices = torch.tensor([0, 2, 1])
+            >>> container.get_number_of_consuming_dims(indices)  # non-bool tensor
+            1
+
+        Note:
+            Used internally by transform_ellipsis_index to calculate how many ':'
+            slices the ellipsis should expand to: rank - sum(consuming_dims).
+        """
         if item is Ellipsis or item is None:
             return 0
         if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
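The boolean-tensor branch mirrors PyTorch advanced indexing, where a mask consumes as many dimensions as its own rank. A quick check with plain tensors (shapes illustrative, not from the PR):

import torch

x = torch.zeros(4, 3, 128)  # two batch dims, one event dim
mask = torch.tensor(
    [[True, False, True], [False, False, False], [True, True, True], [False, True, False]]
)
print(mask.ndim)      # 2 -> a 2-D mask consumes both batch dimensions
print(x[mask].shape)  # torch.Size([6, 128]): six selected entries, event dim preserved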
@@ -438,41 +475,84 @@ def get_number_of_consuming_dims(self, item) -> int:
 
     def transform_ellipsis_index(self, shape: torch.Size, idx: tuple) -> tuple:
         """
-        Transforms an indexing tuple with an ellipsis into an equivalent one without it.
-        ...
+        Transforms an indexing tuple with an ellipsis relative to the container's batch shape.
+
+        This method is essential to TensorContainer's design: containers have batch
+        dimensions (self.shape) but hold individual tensors with varying total shapes.
+        Without this preprocessing, an ellipsis (...) would expand differently for each
+        tensor based on its individual shape, violating container semantics and the
+        batch/event dimension boundary.
+
+        Args:
+            shape: The container's batch shape (self.shape), used as the reference
+                for ellipsis expansion.
+            idx: Indexing tuple potentially containing an ellipsis (...).
+
+        Returns:
+            An equivalent indexing tuple with the ellipsis expanded to explicit slices.
+
+        Example:
+            Container with shape (4, 3) holding tensors (4, 3, 128) and (4, 3, 6, 64):
+
+            # User indexing: container[..., 0]
+            # This method transforms (..., 0) -> (:, 0) based on the batch shape (4, 3).
+            # Applied to the tensors, [:, 0] works consistently on both shapes.
+            # Result: container shape becomes (4,) with tensors (4, 128) and (4, 6, 64).
+
+            # Without this preprocessing, PyTorch would expand the ellipsis per tensor:
+            # Tensor (4, 3, 128):   [..., 0] -> [:, :, 0]    (selects from the 128-sized event dim)
+            # Tensor (4, 3, 6, 64): [..., 0] -> [:, :, :, 0] (selects from the 64-sized event dim)
+            # The per-tensor results would be inconsistent and leak into event dimensions.
+
+        Raises:
+            IndexError: If more than one ellipsis is present, or if there are too many
+                indices for the container's batch dimensions.
+
+        Note:
+            This method is called internally during __getitem__ and __setitem__
+            operations to ensure consistent indexing across all tensors in the container.
         """
-        # Count how many items in the index "consume" an axis from the original shape.
-        # `None` adds a new axis, so it's not counted.
+        # Step 1: Count indices that "consume" container dimensions.
+        # - Ellipsis (...) and None consume nothing (Ellipsis is a placeholder, None adds a new dim).
+        # - int, slice, and tensor indices consume 1 dim each (a bool tensor consumes its ndim).
+        # Example: (..., 0, :) has 2 consuming indices (0 and :); the ellipsis doesn't count.
         num_consuming_indices = sum(
             self.get_number_of_consuming_dims(item) for item in idx
         )
 
+        # Step 2: Validate that there aren't more indices than container dimensions.
+        # Container shape (4, 3) has rank=2, so at most 2 consuming indices are allowed.
         rank = len(shape)
         if num_consuming_indices > rank:
             raise IndexError(
                 f"too many indices for container: container is {rank}-dimensional, "
                 f"but {num_consuming_indices} were indexed"
             )
 
+        # Step 3: Early return if there is no ellipsis - nothing to transform.
         if Ellipsis not in idx:
             return idx
 
+        # Step 4: Validate that only one ellipsis exists (a PyTorch/NumPy requirement).
         ellipsis_count = 0
         for item in idx:
             if item is Ellipsis:
                 ellipsis_count += 1
                 if ellipsis_count > 1:
                     raise IndexError("an index can only have a single ellipsis ('...')")
 
+        # Step 5: Core calculation - how many ':' slices the ellipsis expands to.
+        # Example: container shape (4, 3), index (..., 0)
+        # - rank=2, consuming_indices=1 -> the ellipsis expands to 2-1=1 slice
+        # - Result: (..., 0) becomes (:, 0)
         ellipsis_pos = idx.index(Ellipsis)
-
-        # Calculate slices needed based on the consuming indices
         num_slices_to_add = rank - num_consuming_indices
 
-        part_before_ellipsis = idx[:ellipsis_pos]
-        part_after_ellipsis = idx[ellipsis_pos + 1 :]
-        ellipsis_replacement = (slice(None),) * num_slices_to_add
+        # Step 6: Rebuild the index tuple, replacing the ellipsis with explicit slices.
+        # Split around the ellipsis: (a, ..., b) -> (a,) + (:, ..., :) + (b,)
+        part_before_ellipsis = idx[:ellipsis_pos]  # everything before ...
+        part_after_ellipsis = idx[ellipsis_pos + 1 :]  # everything after ...
+        ellipsis_replacement = (
+            slice(None),
+        ) * num_slices_to_add  # the ':' slices
 
+        # Combine the parts: before + replacement + after
         final_index = part_before_ellipsis + ellipsis_replacement + part_after_ellipsis
 
         return final_index
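The transformation can be exercised outside the class. A minimal standalone sketch of the same calculation (function names here are illustrative, not part of the library):

import torch

def consuming_dims(item) -> int:
    # Ellipsis and None consume no batch dimensions; a boolean mask
    # consumes as many dimensions as its rank; everything else consumes one.
    if item is Ellipsis or item is None:
        return 0
    if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
        return item.ndim
    return 1

def expand_ellipsis(shape: torch.Size, idx: tuple) -> tuple:
    # Replace the single ellipsis with enough ':' slices to cover
    # the batch dimensions left unconsumed by the other indices.
    num_consumed = sum(consuming_dims(i) for i in idx)
    pos = idx.index(Ellipsis)
    fill = (slice(None),) * (len(shape) - num_consumed)
    return idx[:pos] + fill + idx[pos + 1 :]

print(expand_ellipsis(torch.Size([4, 3]), (Ellipsis, 0)))
# -> (slice(None, None, None), 0), i.e. (..., 0) becomes (:, 0)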
@@ -512,7 +592,7 @@ def _format_item(key, value):
             f")"
         )
 
-    def __getitem__(self: Self, key: Any) -> Self:
+    def __getitem__(self: Self, key: IndexType) -> Self:
         """Index into the container along batch dimensions.
 
         Indexing operations are applied to the batch dimensions of all contained tensors.
@@ -557,9 +637,10 @@ def __getitem__(self: Self, key: Any) -> Self:
             raise IndexError(
                 "Cannot index a 0-dimensional TensorContainer with a single index. Use a tuple of indices matching the batch shape, or an empty tuple for a scalar."
             )
+
         return TensorContainer._tree_map(lambda x: x[key], self)
 
-    def __setitem__(self: Self, index: Any, value: Self) -> None:
+    def __setitem__(self: Self, index: IndexType, value: Self) -> None:
         """
         Sets the value of a slice of the container in-place.
 
@@ -582,16 +663,16 @@ def __setitem__(self: Self, index: Any, value: Self) -> None:
             raise ValueError(f"Invalid value. Expected value of type {type(self)}")
 
         processed_index = index
-        if isinstance(processed_index, tuple):
+        if isinstance(index, tuple):
             processed_index = self.transform_ellipsis_index(self.shape, index)
 
-            for k, v in self._pytree_flatten_with_keys_fn()[0]:
-                try:
-                    v[processed_index] = k.get(value)
-                except Exception as e:
-                    raise type(e)(
-                        f"Issue with key {str(k)} and index {processed_index} for value of shape {v.shape} and type {type(v)} and assignment of shape {tuple(value.shape)}"
-                    ) from e
+        for k, v in self._pytree_flatten_with_keys_fn()[0]:
+            try:
+                v[processed_index] = k.get(value)
+            except Exception as e:
+                raise type(e)(
+                    f"Issue with key {str(k)} and index {processed_index} for value of shape {v.shape} and type {type(v)} and assignment of shape {tuple(value.shape)}"
+                ) from e
 
     def view(self: Self, *shape: int) -> Self:
         """Return a view with modified batch dimensions, preserving event dimensions.

src/tensorcontainer/tensor_dict.py

Lines changed: 5 additions & 20 deletions
@@ -26,7 +26,6 @@
 )
 from collections.abc import Iterable, Mapping
 
-import torch
 from torch import Tensor
 from torch.utils._pytree import (
     KeyEntry,
@@ -38,7 +37,7 @@
     TensorContainer,
     TensorContainerPytreeContext,
 )
-from tensorcontainer.types import DeviceLike, ShapeLike
+from tensorcontainer.types import DeviceLike, IndexType, ShapeLike
 from tensorcontainer.utils import PytreeRegistered
 
 TDCompatible = Union[Tensor, TensorContainer]
@@ -292,12 +291,9 @@ def _pytree_unflatten(
     def __getitem__(self, key: str) -> Any: ...
 
     @overload
-    def __getitem__(self, key: slice) -> TensorDict: ...
+    def __getitem__(self, key: IndexType) -> TensorDict: ...
 
-    @overload
-    def __getitem__(self, key: Tensor) -> TensorDict: ...
-
-    def __getitem__(self, key: Any) -> Any:
+    def __getitem__(self, key: str | IndexType) -> Any:
         if isinstance(key, str):
             return self.data[key]
 
@@ -307,13 +303,9 @@ def __getitem__(self, key: Any) -> Any:
     def __setitem__(self, key: str, value: Any) -> None: ...
 
     @overload
-    def __setitem__(
-        self,
-        key: slice | Tensor | int | tuple[slice | Tensor | int, ...],
-        value: Any,
-    ) -> None: ...
+    def __setitem__(self, key: IndexType, value: Any) -> None: ...
 
-    def __setitem__(self, key: Any, value: Any) -> None:
+    def __setitem__(self, key: str | IndexType, value: Any) -> None:
         if isinstance(key, str):
             if isinstance(value, dict):
                 value = TensorDict(value, self.shape, self.device)
@@ -323,13 +315,6 @@ def __setitem__(self, key: Any, value: Any) -> None:
 
             self.data[key] = value
         else:
-            # Handle slicing/indexing assignments via TensorContainer
-            if isinstance(value, (float, int)):
-                # Promote Python scalars to tensors to support slice assignment paths
-                value = torch.tensor(
-                    value, device=self.device, dtype=torch.float32
-                )  # scalar promotion for container setitem
-
             super().__setitem__(key, value)
 
     def __delitem__(self, key: str):
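After this change, the overloads dispatch purely on key type: a str addresses an entry of the mapping, while anything matching IndexType is forwarded to TensorContainer and slices the batch dimensions of every leaf. A sketch with illustrative keys and shapes:

import torch
from tensorcontainer.tensor_dict import TensorDict

td = TensorDict({"obs": torch.zeros(4, 3, 128)}, (4, 3))

td["reward"] = torch.zeros(4, 3)  # str key: stores a new entry
obs = td["obs"]                   # str key: returns the Tensor

sliced = td[1:3]  # IndexType key: a TensorDict with batch shape (2, 3)
masked = td[torch.tensor([True, False, True, False])]  # batch shape (2, 3)

Note that with the scalar-promotion branch removed, a slice assignment such as td[0] = 1.0 is no longer converted to a tensor; per the validation in TensorContainer.__setitem__, the assigned value must be a container of the same type.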

src/tensorcontainer/types.py

Lines changed: 8 additions & 2 deletions
@@ -6,13 +6,19 @@
 will fail if PyTorch changes it in the future.
 """
 
-from typing import Union
+from typing import List, Union
 import torch
 
+# Define EllipsisType for Python 3.9 compatibility
+EllipsisType = type(...)
+
 # Mirror torch._prims_common.ShapeType without importing it directly.
 ShapeLike = Union[torch.Size, list[int], tuple[int, ...]]
 
 # Mirror torch._prims_common.DeviceLikeType without importing it directly.
 DeviceLike = Union[str, torch.device, int]
 
-__all__ = ["ShapeLike", "DeviceLike"]
+# Type for tensor indexing operations, covering all supported PyTorch indexing patterns.
+IndexType = Union[int, slice, torch.Tensor, tuple, EllipsisType, None, List[int]]
+
+__all__ = ["ShapeLike", "DeviceLike", "IndexType"]
