Skip to content

Commit 01901e2

Browse files
authored
Merge pull request #21 from mctigger/tree_map-error-messages
Improved PyTree Error Messages and Refactoring
2 parents 233da76 + 1338533 commit 01901e2

File tree

10 files changed

+962
-94
lines changed

10 files changed

+962
-94
lines changed

examples/tensor_dataclass/04_stack.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class DataPoint(TensorDataClass):
2525
y: torch.Tensor
2626

2727

28+
class NotADataPoint(TensorDataClass):
29+
a: torch.Tensor
30+
31+
2832
def main() -> None:
2933
"""Demonstrate stacking TensorDataClass instances."""
3034
point1 = DataPoint(
@@ -49,6 +53,13 @@ def main() -> None:
4953
assert stacked.x.shape == (2, 3, 4)
5054
assert stacked.y.shape == (2, 3, 5)
5155

56+
# Attempt to stack with different TensorDataClasses
57+
try:
58+
not_a_point = NotADataPoint(a=torch.rand(3, 4), shape=(3,), device="cpu")
59+
torch.stack([point1, not_a_point], dim=0)
60+
except Exception as e:
61+
print(e)
62+
5263

5364
if __name__ == "__main__":
5465
main()

src/tensorcontainer/tensor_annotated.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
34
from typing import Any, Iterable, TypeVar, Union, get_args
45

56
from torch import Tensor
67
from torch.utils import _pytree as pytree
78
from typing_extensions import Self
89

9-
from tensorcontainer.tensor_container import TensorContainer
10+
from tensorcontainer.tensor_container import (
11+
TensorContainer,
12+
TensorContainerPytreeContext,
13+
)
1014
from tensorcontainer.types import DeviceLike, ShapeLike
1115
from tensorcontainer.utils import PytreeRegistered
1216

@@ -17,6 +21,54 @@
1721
T_TensorAnnotated = TypeVar("T_TensorAnnotated", bound="TensorAnnotated")
1822

1923

24+
# PyTree context metadata for reconstruction
25+
@dataclass
26+
class TensorAnnotatedPytreeContext(
27+
TensorContainerPytreeContext["TensorAnnotatedPytreeContext"]
28+
):
29+
"""TensorAnnotated PyTree context with enhanced error messages."""
30+
31+
keys: list[str]
32+
event_ndims: list[int]
33+
metadata: dict[str, Any]
34+
35+
def __str__(self) -> str:
36+
"""Return human-readable description of this TensorAnnotated context."""
37+
# Try to get the actual class name from metadata
38+
class_name = self.metadata.get("class_name", "TensorDataClass")
39+
40+
fields_str = f"fields={self.keys}"
41+
device_str = f"device={self.device}"
42+
43+
return f"{class_name}({fields_str}, {device_str})"
44+
45+
def analyze_mismatch_with(
46+
self, other: TensorAnnotatedPytreeContext, entry_index: int
47+
) -> str:
48+
"""Analyze specific mismatches between TensorAnnotated contexts."""
49+
# Start with base class analysis (device mismatch, if any)
50+
guidance = super().analyze_mismatch_with(other, entry_index)
51+
52+
# Add TensorAnnotated-specific analysis
53+
self_fields = set(self.keys)
54+
other_fields = set(other.keys)
55+
56+
if self_fields != other_fields:
57+
missing = self_fields - other_fields
58+
extra = other_fields - self_fields
59+
guidance += "Field mismatch detected."
60+
if missing:
61+
guidance += (
62+
f" Missing fields in container {entry_index}: {sorted(missing)}."
63+
)
64+
if extra:
65+
guidance += (
66+
f" Extra fields in container {entry_index}: {sorted(extra)}."
67+
)
68+
69+
return guidance
70+
71+
2072
class TensorAnnotated(TensorContainer, PytreeRegistered):
2173
def __init__(self, shape: ShapeLike, device: DeviceLike | None):
2274
super().__init__(shape, device)
@@ -73,11 +125,13 @@ def _get_meta_attributes(self):
73125

74126
def _get_pytree_context(
75127
self, flat_names: list[str], flat_leaves: list[TDCompatible], meta_data
76-
) -> tuple:
128+
) -> TensorAnnotatedPytreeContext:
77129
batch_ndim = len(self.shape)
78-
event_ndims = tuple(leaf.ndim - batch_ndim for leaf in flat_leaves)
130+
event_ndims = [leaf.ndim - batch_ndim for leaf in flat_leaves]
79131

80-
return flat_names, event_ndims, meta_data, self.device
132+
return TensorAnnotatedPytreeContext(
133+
self.device, flat_names, event_ndims, meta_data
134+
)
81135

82136
def _pytree_flatten(self) -> tuple[list[Any], Any]:
83137
tensor_attributes = self._get_tensor_attributes()
@@ -94,15 +148,20 @@ def _pytree_flatten_with_keys_fn(
94148
self,
95149
) -> tuple[list[tuple[pytree.KeyEntry, Any]], Any]:
96150
flat_values, context = self._pytree_flatten()
97-
flat_names = context[0]
151+
flat_names = context.keys
98152
name_value_tuples = [
99153
(pytree.GetAttrKey(k), v) for k, v in zip(flat_names, flat_values)
100154
]
101155
return name_value_tuples, context # type: ignore[return-value]
102156

103157
@classmethod
104-
def _pytree_unflatten(cls, leaves: Iterable[Any], context: pytree.Context) -> Self:
105-
flat_names, event_ndims, meta_data, device = context
158+
def _pytree_unflatten(
159+
cls, leaves: Iterable[Any], context: TensorAnnotatedPytreeContext
160+
) -> Self:
161+
flat_names = context.keys
162+
event_ndims = context.event_ndims
163+
device = context.device
164+
meta_data = context.metadata
106165

107166
leaves = list(leaves) # Convert to list to allow indexing
108167

src/tensorcontainer/tensor_container.py

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
34
import functools
45
import textwrap
56
import threading
6-
from abc import abstractmethod
7+
from abc import ABC, abstractmethod
78
from contextlib import contextmanager
8-
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
9+
from typing import (
10+
Any,
11+
Callable,
12+
Generic,
13+
Iterable,
14+
List,
15+
Optional,
16+
Tuple,
17+
Type,
18+
TypeVar,
19+
Union,
20+
)
921

1022
import torch
1123

@@ -15,7 +27,12 @@
1527
from typing_extensions import Self, TypeAlias
1628

1729
from tensorcontainer.types import DeviceLike, ShapeLike
18-
from tensorcontainer.utils import resolve_device
30+
from tensorcontainer.utils import (
31+
ContextWithAnalysis,
32+
diagnose_pytree_structure_mismatch,
33+
resolve_device,
34+
format_path,
35+
)
1936

2037
HANDLED_FUNCTIONS = {}
2138

@@ -33,6 +50,25 @@ def decorator(func):
3350
return decorator
3451

3552

53+
U = TypeVar("U", bound="TensorContainerPytreeContext")
54+
55+
56+
@dataclass
57+
class TensorContainerPytreeContext(ContextWithAnalysis[U], Generic[U], ABC):
58+
"""Base PyTree context class for tensor containers with common device handling."""
59+
60+
device: torch.device | None
61+
62+
def analyze_mismatch_with(self, other: U, entry_index: int) -> str:
63+
"""Analyze mismatches with another TensorContainerPytreeContext, starting with device analysis."""
64+
# Check device mismatch first
65+
if self.device != other.device:
66+
return f"Device mismatch: container 0 device={self.device}, container {entry_index} device={other.device}. "
67+
68+
# If devices match, return empty string for subclasses to add their analysis
69+
return ""
70+
71+
3672
class TensorContainer:
3773
"""A foundational base class for PyTree-compatible tensor containers with batch semantics.
3874
@@ -308,11 +344,25 @@ def wrapped_func(keypath, x, *xs):
308344
try:
309345
return func(x, *xs)
310346
except Exception as e:
311-
path = cls._format_path(keypath)
347+
path = format_path(keypath)
312348
message = f"Error at path {path}: {type(e).__name__}: {e}"
313349
raise type(e)(message) from e
314350

315-
return cls.tree_map_with_path(wrapped_func, tree, *rests, is_leaf=is_leaf)
351+
try:
352+
return pytree.tree_map_with_path(
353+
wrapped_func, tree, *rests, is_leaf=is_leaf
354+
)
355+
except Exception as e:
356+
# The following code is just to provide better error messages for operations that
357+
# work on multiple pytrees such as torch.stack() or torch.cat()
358+
# It is not necessary for TensorContainer to function properly.
359+
if len(rests) > 0:
360+
msg = diagnose_pytree_structure_mismatch(tree, *rests, is_leaf=is_leaf)
361+
if msg:
362+
raise RuntimeError(msg) from e
363+
364+
# Re-raise if it is an unknown error.
365+
raise e
316366

317367
@classmethod
318368
def tree_map_with_path(
@@ -431,22 +481,6 @@ def transform_ellipsis_index(self, shape: torch.Size, idx: tuple) -> tuple:
431481

432482
return final_index
433483

434-
@classmethod
435-
def _format_path(cls, path: pytree.KeyPath) -> str:
436-
"""Helper to format a PyTree KeyPath into a readable string."""
437-
parts = []
438-
for entry in path:
439-
if isinstance(entry, tuple): # Handle nested KeyPath tuples
440-
parts.append(cls._format_path(entry))
441-
else:
442-
parts.append(str(entry))
443-
444-
# Join parts and clean up leading dots if any
445-
formatted_path = "".join(parts)
446-
if formatted_path.startswith("."):
447-
formatted_path = formatted_path[1:]
448-
return formatted_path
449-
450484
def __repr__(self) -> str:
451485
# Use a consistent indent of 4 spaces, which is standard
452486
indent = " "
@@ -863,14 +897,10 @@ def _stack(
863897
dim = dim + batch_ndim + 1
864898

865899
if dim < 0 or dim > batch_ndim:
866-
raise IndexError("Dimension out of range")
867-
868-
shape_expected = first_tc.shape
869-
870-
for t in tensors:
871-
shape_is = t.shape
872-
if shape_is != shape_expected:
873-
raise ValueError("stack expects each TensorContainer to be equal size")
900+
raise IndexError(
901+
f"Dimension {dim - batch_ndim - 1 if dim < 0 else dim} out of range "
902+
f"(expected 0 to {batch_ndim} for stack operation on shape {tuple(first_tc.shape)})"
903+
)
874904

875905
# Pytree handles the stacking of individual tensors and metadata consistency
876906
result_td = TensorContainer._tree_map(lambda *x: torch.stack(x, dim), *tensors)
@@ -891,16 +921,10 @@ def _cat(
891921
dim = dim + batch_ndim
892922

893923
if dim < 0 or dim > batch_ndim - 1:
894-
raise IndexError("Dimension out of range")
895-
896-
shape_expected = first_tc.shape[:dim] + first_tc.shape[dim + 1 :]
897-
898-
for t in tensors:
899-
shape_is = t.shape[:dim] + t.shape[dim + 1 :]
900-
if shape_is != shape_expected:
901-
raise ValueError(
902-
"TensorContainer batch shapes must be identical except for 'dim'"
903-
)
924+
raise IndexError(
925+
f"Dimension {dim - batch_ndim if dim < 0 else dim} out of range "
926+
f"(expected 0 to {batch_ndim - 1} for concatenation on shape {tuple(first_tc.shape)})"
927+
)
904928

905929
# Create a new TensorContainer of the same type as the first one
906930
# and apply torch.cat to its internal tensors

src/tensorcontainer/tensor_dict.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616

1717
from __future__ import annotations
1818

19+
from dataclasses import dataclass
1920
from typing import (
2021
Any,
2122
Dict,
2223
Iterable,
2324
List,
2425
Mapping,
25-
NamedTuple,
2626
Tuple,
2727
Union,
2828
cast,
@@ -38,20 +38,56 @@
3838
PyTree,
3939
)
4040

41-
from tensorcontainer.tensor_container import TensorContainer
41+
from tensorcontainer.tensor_container import (
42+
TensorContainer,
43+
TensorContainerPytreeContext,
44+
)
4245
from tensorcontainer.types import DeviceLike, ShapeLike
4346
from tensorcontainer.utils import PytreeRegistered
4447

4548
TDCompatible = Union[Tensor, TensorContainer]
4649

4750

4851
# PyTree context metadata for reconstruction
49-
class TensorDictPytreeContext(NamedTuple):
50-
keys: Tuple[str, ...]
51-
event_ndims: Tuple[int, ...]
52-
device_context: torch.device | None
52+
@dataclass
53+
class TensorDictPytreeContext(TensorContainerPytreeContext["TensorDictPytreeContext"]):
54+
"""TensorDict PyTree context with enhanced error messages."""
55+
56+
keys: list[str]
57+
event_ndims: list[int]
5358
metadata: Dict[str, Any]
5459

60+
def __str__(self) -> str:
61+
"""Return human-readable description of this TensorDict context."""
62+
keys_str = f"keys={list(self.keys)}" if self.keys else "keys=[]"
63+
device_str = f"device={self.device}" if self.device else "device=None"
64+
65+
return f"TensorDict({keys_str}, {device_str})"
66+
67+
def analyze_mismatch_with(
68+
self, other: TensorDictPytreeContext, entry_index: int
69+
) -> str:
70+
"""Analyze specific mismatches with another TensorDict context."""
71+
# Start with base class analysis (device mismatch, if any)
72+
guidance = super().analyze_mismatch_with(other, entry_index)
73+
74+
# Add TensorDict-specific analysis
75+
self_keys = set(self.keys)
76+
other_keys = set(other.keys)
77+
78+
if self_keys != other_keys:
79+
missing = self_keys - other_keys
80+
extra = other_keys - self_keys
81+
guidance += "Key mismatch detected."
82+
if missing:
83+
guidance += (
84+
f" Missing keys in container {entry_index}: {sorted(missing)}."
85+
)
86+
if extra:
87+
guidance += f" Extra keys in container {entry_index}: {sorted(extra)}."
88+
89+
return guidance
90+
5591

5692
class TensorDict(TensorContainer, PytreeRegistered):
5793
"""Dictionary-like container for batched tensors that share the same leading batch shape.
@@ -159,7 +195,7 @@ def _get_pytree_context(
159195
"""
160196
batch_ndim = len(self.shape)
161197
event_ndims = tuple(leaf.ndim - batch_ndim for leaf in flat_leaves)
162-
return TensorDictPytreeContext(tuple(keys), event_ndims, self.device, metadata)
198+
return TensorDictPytreeContext(self.device, tuple(keys), event_ndims, metadata)
163199

164200
def _pytree_flatten(
165201
self,
@@ -225,8 +261,11 @@ def _pytree_unflatten(
225261
- If no leaves are provided, an empty ``TensorDict`` is constructed using the shape
226262
from the context. The device is restored from the context.
227263
"""
228-
# Unpack context tuple
229-
keys, event_ndims, device_context, metadata = context
264+
# Access context fields
265+
keys = context.keys
266+
event_ndims = context.event_ndims
267+
device_context = context.device
268+
metadata = context.metadata
230269

231270
obj = cls.__new__(cls)
232271
obj.device = device_context

0 commit comments

Comments
 (0)