from torch.utils._pytree import Context, KeyEntry, PyTree
from typing_extensions import Self, TypeAlias

- from tensorcontainer.types import DeviceLike, ShapeLike
+ from tensorcontainer.types import DeviceLike, IndexType, ShapeLike
from tensorcontainer.utils import (
    ContextWithAnalysis,
    diagnose_pytree_structure_mismatch,
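`IndexType` is imported above but defined outside this diff. A plausible sketch of such an alias, assuming it simply names the index forms `torch.Tensor.__getitem__` accepts; the names and layout here are illustrative, not the actual `tensorcontainer.types` source:

```python
# Hypothetical sketch only; the real tensorcontainer.types.IndexType is not
# shown in this diff. Requires Python 3.10+ for types.EllipsisType.
from types import EllipsisType
from typing import Union

import torch
from typing_extensions import TypeAlias

IndexType: TypeAlias = Union[
    int,           # container[0]
    slice,         # container[1:3]
    EllipsisType,  # container[...]
    None,          # container[None] adds a new batch dim
    torch.Tensor,  # boolean masks and integer index tensors
    tuple,         # any mix of the above, e.g. container[..., 0]
]
```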
@@ -429,6 +429,43 @@ def copy(self) -> Self:
        return self._tree_map(lambda x: x, self)

    def get_number_of_consuming_dims(self, item) -> int:
+        """
+        Returns the number of container dimensions consumed by an indexing item.
+
+        This count drives the ellipsis expansion calculation. "Consuming" means
+        the index item selects from existing container dimensions, reducing the
+        container's rank. "Non-consuming" items either leave existing dimensions
+        untouched (Ellipsis) or add new ones (None).
+
+        Args:
+            item: A single element of an indexing tuple.
+
+        Returns:
+            The number of container dimensions this item consumes:
+            - 0 for non-consuming items (Ellipsis, None)
+            - item.ndim for boolean tensors (advanced indexing)
+            - 1 for standard consuming items (int, slice, non-bool tensor)
+
+        Examples:
+            >>> container.get_number_of_consuming_dims(0)  # int
+            1
+            >>> container.get_number_of_consuming_dims(slice(0, 2))  # slice
+            1
+            >>> container.get_number_of_consuming_dims(...)  # Ellipsis
+            0
+            >>> container.get_number_of_consuming_dims(None)  # None (newaxis)
+            0
+            >>> bool_mask = torch.tensor([[True, False], [False, True]])
+            >>> container.get_number_of_consuming_dims(bool_mask)  # 2D bool tensor
+            2
+            >>> indices = torch.tensor([0, 2, 1])
+            >>> container.get_number_of_consuming_dims(indices)  # non-bool tensor
+            1
+
+        Note:
+            Used internally by transform_ellipsis_index to calculate how many ':'
+            slices the ellipsis should expand to: rank - sum(consuming_dims).
+        """
        if item is Ellipsis or item is None:
            return 0
        if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
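The consuming-dims rule is easy to verify standalone. A minimal sketch, restating the method above as a free function for illustration (not part of this diff):

```python
import torch

def consuming_dims(item) -> int:
    # Mirrors get_number_of_consuming_dims: Ellipsis and None consume no
    # container dims, a boolean mask consumes one dim per mask dimension,
    # and every other index item consumes exactly one.
    if item is Ellipsis or item is None:
        return 0
    if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
        return item.ndim
    return 1

idx = (Ellipsis, 0, None, torch.tensor([[True, False]]))
# Ellipsis -> 0, int -> 1, None -> 0, 2D bool mask -> 2
assert [consuming_dims(i) for i in idx] == [0, 1, 0, 2]
```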
@@ -438,41 +475,84 @@ def get_number_of_consuming_dims(self, item) -> int:

    def transform_ellipsis_index(self, shape: torch.Size, idx: tuple) -> tuple:
        """
-        Transforms an indexing tuple with an ellipsis into an equivalent one without it.
-        ...
+        Transforms an indexing tuple containing an ellipsis into an equivalent
+        tuple, expanded relative to the container's batch shape.
+
+        This preprocessing is essential to TensorContainer's design: a container has
+        batch dimensions (self.shape) but holds individual tensors whose total shapes
+        vary. Without it, ellipsis (...) would expand differently for each tensor,
+        based on that tensor's own shape, violating container semantics and the
+        batch/event dimension boundary.
+
+        Args:
+            shape: The container's batch shape (self.shape), used as the reference
+                for ellipsis expansion.
+            idx: An indexing tuple that may contain an ellipsis (...).
+
+        Returns:
+            An equivalent indexing tuple with the ellipsis expanded to explicit slices.
+
+        Example:
+            Container with shape (4, 3) holding tensors (4, 3, 128) and (4, 3, 6, 64):
+
+            # User indexing: container[..., 0]
+            # This method transforms (..., 0) -> (:, 0) based on the batch shape (4, 3).
+            # Applied to the tensors, [:, 0] works consistently on both shapes.
+            # Result: container shape becomes (4,) with tensors (4, 128) and (4, 6, 64).
+
+            # Without this preprocessing, PyTorch would expand the ellipsis per tensor:
+            # Tensor (4, 3, 128):   [..., 0] -> [:, :, 0]    selects from the 128-sized event dim
+            # Tensor (4, 3, 6, 64): [..., 0] -> [:, :, :, 0] selects from the 64-sized event dim
+            # Each leaf would index a different (event) dimension, giving inconsistent results.
+
+        Raises:
+            IndexError: If more than one ellipsis is present, or if there are too
+                many indices for the container's batch dimensions.
+
+        Note:
+            Called internally during __getitem__ and __setitem__ to ensure
+            consistent indexing behavior across all tensors in the container.
        """
-        # Count how many items in the index "consume" an axis from the original shape.
-        # `None` adds a new axis, so it's not counted.
+        # Step 1: Count indices that "consume" container dimensions.
+        # - Ellipsis (...) and None consume nothing (Ellipsis is a placeholder, None adds a new dim).
+        # - int, slice, and tensor indices consume 1 dim each (a bool tensor consumes its ndim).
+        # Example: (..., 0, :) has 2 consuming indices (0 and :); the ellipsis doesn't count.
        num_consuming_indices = sum(
            self.get_number_of_consuming_dims(item) for item in idx
        )

+        # Step 2: Validate that there aren't more indices than container dimensions.
+        # Container shape (4, 3) has rank 2, so at most 2 consuming indices are allowed.
        rank = len(shape)
        if num_consuming_indices > rank:
            raise IndexError(
                f"too many indices for container: container is {rank}-dimensional, "
                f"but {num_consuming_indices} were indexed"
            )

+        # Step 3: Early return if there is no ellipsis - nothing to transform.
        if Ellipsis not in idx:
            return idx

+        # Step 4: Validate that only one ellipsis exists (a PyTorch/NumPy requirement).
        ellipsis_count = 0
        for item in idx:
            if item is Ellipsis:
                ellipsis_count += 1
                if ellipsis_count > 1:
                    raise IndexError("an index can only have a single ellipsis ('...')")

+        # Step 5: Core calculation - how many ':' slices the ellipsis expands to.
+        # Example: container shape (4, 3), index (..., 0)
+        # - rank=2, consuming indices=1 -> the ellipsis expands to 2-1=1 slice
+        # - Result: (..., 0) becomes (:, 0)
        ellipsis_pos = idx.index(Ellipsis)
-
-        # Calculate slices needed based on the consuming indices
        num_slices_to_add = rank - num_consuming_indices

-        part_before_ellipsis = idx[:ellipsis_pos]
-        part_after_ellipsis = idx[ellipsis_pos + 1:]
-        ellipsis_replacement = (slice(None),) * num_slices_to_add
+        # Step 6: Reconstruct the index tuple, replacing the ellipsis with explicit slices.
+        # Split around the ellipsis: (a, ..., b) -> (a,) + (:,) * k + (b,)
+        part_before_ellipsis = idx[:ellipsis_pos]  # everything before ...
+        part_after_ellipsis = idx[ellipsis_pos + 1:]  # everything after ...
+        ellipsis_replacement = (slice(None),) * num_slices_to_add  # the ':' slices

+        # Combine the parts: before + replacement + after.
        final_index = part_before_ellipsis + ellipsis_replacement + part_after_ellipsis

        return final_index
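End to end, the expansion can be checked against plain tensors. A self-contained sketch of the same steps, with the validation omitted for brevity; `expand_ellipsis` is an illustrative stand-in for the method above, not this class's API:

```python
import torch

def expand_ellipsis(shape: torch.Size, idx: tuple) -> tuple:
    # Steps 1, 5 and 6 from above: count consuming items, compute how many
    # ':' slices the ellipsis stands for, then splice them in.
    consuming = sum(
        0 if i is Ellipsis or i is None
        else i.ndim if isinstance(i, torch.Tensor) and i.dtype == torch.bool
        else 1
        for i in idx
    )
    if Ellipsis not in idx:
        return idx
    pos = idx.index(Ellipsis)
    fill = (slice(None),) * (len(shape) - consuming)
    return idx[:pos] + fill + idx[pos + 1:]

# Batch shape (4, 3): (..., 0) expands against rank 2 to (:, 0)
expanded = expand_ellipsis(torch.Size([4, 3]), (Ellipsis, 0))
assert expanded == (slice(None), 0)

# Both leaf tensors now index consistently along the batch dims:
assert torch.zeros(4, 3, 128)[expanded].shape == (4, 128)
assert torch.zeros(4, 3, 6, 64)[expanded].shape == (4, 6, 64)
```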
@@ -512,7 +592,7 @@ def _format_item(key, value):
            f")"
        )

-    def __getitem__(self: Self, key: Any) -> Self:
+    def __getitem__(self: Self, key: IndexType) -> Self:
        """Index into the container along batch dimensions.

        Indexing operations are applied to the batch dimensions of all contained tensors.
@@ -557,9 +637,10 @@ def __getitem__(self: Self, key: Any) -> Self:
            raise IndexError(
                "Cannot index a 0-dimensional TensorContainer with a single index. Use a tuple of indices matching the batch shape, or an empty tuple for a scalar."
            )
+
        return TensorContainer._tree_map(lambda x: x[key], self)

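For reference, the per-leaf semantics of the `_tree_map` call above, demonstrated on a plain dict of tensors as an illustrative stand-in for a container with batch shape (4, 3); this is not the class's API:

```python
import torch

# Leaves share the batch dims (4, 3) but differ in event dims.
leaves = {
    "obs": torch.zeros(4, 3, 128),
    "act": torch.zeros(4, 3, 6),
}

# container[1:3] is tree_map(lambda x: x[1:3], container): the same key is
# applied to the batch dims of every leaf.
sliced = {name: t[1:3] for name, t in leaves.items()}
assert sliced["obs"].shape == (2, 3, 128)
assert sliced["act"].shape == (2, 3, 6)
```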
-    def __setitem__(self: Self, index: Any, value: Self) -> None:
+    def __setitem__(self: Self, index: IndexType, value: Self) -> None:
        """
        Sets the value of a slice of the container in-place.

@@ -582,16 +663,16 @@ def __setitem__(self: Self, index: Any, value: Self) -> None:
            raise ValueError(f"Invalid value. Expected value of type {type(self)}")

        processed_index = index
-        if isinstance(processed_index, tuple):
+        if isinstance(index, tuple):
            processed_index = self.transform_ellipsis_index(self.shape, index)

-        for k, v in self._pytree_flatten_with_keys_fn()[0]:
-            try:
-                v[processed_index] = k.get(value)
-            except Exception as e:
-                raise type(e)(
-                    f"Issue with key {str(k)} and index {processed_index} for value of shape {v.shape} and type {type(v)} and assignment of shape {tuple(value.shape)}"
-                ) from e
+        for k, v in self._pytree_flatten_with_keys_fn()[0]:
+            try:
+                v[processed_index] = k.get(value)
+            except Exception as e:
+                raise type(e)(
+                    f"Issue with key {str(k)} and index {processed_index} for value of shape {v.shape} and type {type(v)} and assignment of shape {tuple(value.shape)}"
+                ) from e

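The try/except above enriches any leaf-level failure with the key, the index, and the shapes involved, which is hard to reconstruct from PyTorch's bare error. The same pattern in a standalone sketch, using a hypothetical `set_leaves` helper rather than this class's API:

```python
import torch

def set_leaves(leaves: dict, index, values: dict) -> None:
    # In-place per-leaf assignment, re-raising with context on failure,
    # mirroring __setitem__ above.
    for key, tensor in leaves.items():
        try:
            tensor[index] = values[key]
        except Exception as e:
            raise type(e)(
                f"Issue with key {key} and index {index} for value of shape "
                f"{tuple(tensor.shape)} and assignment of shape {tuple(values[key].shape)}"
            ) from e

leaves = {"obs": torch.zeros(4, 3, 128)}
set_leaves(leaves, (slice(None), 0), {"obs": torch.ones(4, 128)})
assert bool(leaves["obs"][:, 0].eq(1).all())
```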
    def view(self: Self, *shape: int) -> Self:
        """Return a view with modified batch dimensions, preserving event dimensions.