Skip to content

Commit 605c804

Browse files
authored
Merge pull request #13 from mctigger/tensor-dict-types
This PR introduces significant improvements to the TensorContainer type system and documentation, along with a new symlog distribution implementation: - Enhanced type system: Added DeviceLike and ShapeLike type aliases for consistent typing across the codebase, replacing internal PyTorch types with public API equivalents - Comprehensive documentation: Added detailed design decisions document explaining TensorContainer architecture, batch/event dimension separation, PyTree integration, and performance optimization strategies• New symlog distribution: Implemented SymLogDistribution with bijective SymexpTransform for modeling data with wide dynamic ranges - Improved validation: Refactored tensor container initialization to always validate by default, removing the validate_args parameter • Better PyTree integration: Enhanced metadata preservation during pytree operations and improved nested container interaction tests - Device resolution improvements: Streamlined device compatibility logic with better error handling
2 parents 1d7d2d2 + 05b6cc4 commit 605c804

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+1982
-762
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
*Tensor containers for PyTorch with PyTree compatibility and torch.compile optimization*
44

5-
[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
5+
[![Python 3.9, 3.10, 3.11, 3.12](https://img.shields.io/badge/python-3.10%20%7C%203.11%20%7C%203.11%20%7C%203.12-blue)](https://www.python.org/downloads/)
66
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
7-
[![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org/)
7+
[![PyTorch](https://img.shields.io/badge/PyTorch-2.6+-blue.svg)](https://pytorch.org/)
88

99
> **⚠️ Academic Research Project**: This project exists solely for academic purposes to explore and learn PyTorch internals. For production use, please use the official, well-maintained [**torch/tensordict**](https://github.com/pytorch/tensordict) library.
1010

docs/tensor_container.md

Lines changed: 308 additions & 0 deletions
Large diffs are not rendered by default.

docs/tensor_distribution/development.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@ All classes in [`tensorcontainer.tensor_distribution`](/src/tensorcontainer/tens
2626

2727
Many `torch.distributions` constructors accept parameters of type `Union[Number, Tensor]` or any specialization of `Number` (e.g. `float`). However, [`TensorContainer`](/src/tensorcontainer/tensor_container.py) and [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) can only process `Union[Tensor, TensorContainer]` objects and require all parameters to have compatible shapes for broadcasting.
2828

29-
**Implementation Rule**: When the constructor signature contains `Union[Number, Tensor]` or any specialization of `Number` parameters, implementations **must** use `torch.distributions.utils.broadcast_all` to:
29+
**Implementation Rule**: When the constructor signature contains `Union[Number, Tensor]` or any specialization of `Number` parameters, implementations **must** use `tensorcontainer.tensor_distribution.utils.broadcast_all` to:
3030
1. Convert scalar numbers to tensors
3131
2. Broadcast all parameters to a common shape
3232

33+
3334
This preprocessing ensures proper shape and device management within the [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) framework.
3435

35-
**Decision Criterion**: If the constructor signature does not contain `Union[Number, Tensor]` parameters, simpler parameter handling approaches should be preferred.
36+
**Decision Criterion**: If the constructor signature does not contain `Union[Number, Tensor]` parameters, simpler parameter handling approaches should be preferred. E.g. if it only contains a single argument of type `Tensor`, broadcasting is not necessary and should be avoided.
3637

3738
### Validation Strategy
3839

@@ -46,7 +47,7 @@ This preprocessing ensures proper shape and device management within the [`Tenso
4647

4748
Following the `torch.distributions.Distribution` pattern, basic distribution properties are provided through the [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) base class via delegation to `self.dist()`.
4849

49-
**Specialization Rule**: Distribution-specific properties **must** be implemented only in the corresponding subclass, maintaining the same delegation pattern to the underlying `torch.distributions` object.
50+
**Specialization Rule**: Distribution-specific properties (such as `logits` and `probs` in `Categorical`) **must** be implemented only in the corresponding subclass, maintaining the same delegation pattern to the underlying `torch.distributions` object.
5051

5152
## Implementation Patterns
5253

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .symlog import SymLogDistribution, SymexpTransform, symexp, symlog
2+
3+
__all__ = ["SymLogDistribution", "SymexpTransform", "symexp", "symlog"]
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
from torch.distributions import (
5+
Normal,
6+
Transform,
7+
TransformedDistribution,
8+
constraints,
9+
)
10+
from typing import Any
11+
12+
13+
def symlog(x: torch.Tensor) -> torch.Tensor:
14+
"""
15+
Applies the symlog function element-wise.
16+
17+
symlog(x) = sign(x) * log(1 + |x|)
18+
"""
19+
return torch.sign(x) * torch.log(1 + torch.abs(x))
20+
21+
22+
def symexp(x: torch.Tensor) -> torch.Tensor:
23+
"""
24+
Applies the symexp function element-wise.
25+
26+
symexp(x) = sign(x) * (exp(|x|) - 1)
27+
"""
28+
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
29+
30+
31+
class SymexpTransform(Transform):
32+
"""
33+
A bijective transform implementing the symexp function.
34+
35+
This transform is its own inverse, applying symlog. It is used to warp a
36+
base distribution into a symlog-space.
37+
"""
38+
39+
def __init__(self) -> None:
40+
super().__init__()
41+
self.bijective = True
42+
self.domain = constraints.real
43+
self.codomain = constraints.real
44+
45+
def _call(self, x: torch.Tensor) -> torch.Tensor:
46+
return symexp(x)
47+
48+
def _inverse(self, y: torch.Tensor) -> torch.Tensor:
49+
return symlog(y)
50+
51+
def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
52+
# For y = symexp(x), dy/dx = exp(|x|)
53+
# log|dy/dx| = log(exp(|x|)) = |x|
54+
return torch.abs(x)
55+
56+
@property
57+
def sign(self) -> int:
58+
"""The sign of the transform (always positive for symexp)."""
59+
return 1
60+
61+
62+
class SymLogDistribution(TransformedDistribution):
63+
"""
64+
A distribution that transforms a Normal distribution with a symexp transform.
65+
66+
This distribution is useful for modeling data with a wide dynamic range,
67+
where the data can be both positive and negative, and can have values
68+
close to zero. The symlog transform compresses large values and expands
69+
small values, making the distribution more stable for optimization.
70+
71+
Args:
72+
loc (torch.Tensor): The mean of the base Normal distribution.
73+
scale (torch.Tensor): The standard deviation of the base Normal distribution.
74+
validate_args (bool, optional): Whether to validate the arguments.
75+
Defaults to None.
76+
"""
77+
78+
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
79+
80+
def __init__(
81+
self,
82+
loc: torch.Tensor,
83+
scale: torch.Tensor,
84+
validate_args: bool | None = None,
85+
) -> None:
86+
self._loc = loc
87+
self._scale = scale
88+
base_dist = Normal(loc, scale)
89+
super().__init__(base_dist, SymexpTransform(), validate_args=validate_args)
90+
91+
@property
92+
def loc(self) -> torch.Tensor:
93+
return self._loc
94+
95+
@property
96+
def scale(self) -> torch.Tensor:
97+
return self._scale
98+
99+
@property
100+
def mean(self) -> torch.Tensor:
101+
"""Approximated by mode for now, as per instructions."""
102+
return self.mode
103+
104+
@property
105+
def mode(self) -> torch.Tensor:
106+
"""The mode of the distribution."""
107+
return symexp(self._loc)
108+
109+
def expand(
110+
self, batch_shape: Any, _instance: SymLogDistribution | None = None
111+
) -> SymLogDistribution:
112+
"""
113+
Returns a new distribution instance with expanded batch shape.
114+
115+
Args:
116+
batch_shape (Any): The new batch shape.
117+
_instance (SymLogDistribution, optional): The instance to expand.
118+
Defaults to None.
119+
120+
Returns:
121+
SymLogDistribution: The expanded distribution.
122+
"""
123+
new = self._get_checked_instance(SymLogDistribution, _instance)
124+
batch_shape = torch.Size(batch_shape)
125+
new._loc = self._loc.expand(batch_shape)
126+
new._scale = self._scale.expand(batch_shape)
127+
base_dist = Normal(new._loc, new._scale)
128+
super(SymLogDistribution, new).__init__(
129+
base_dist, SymexpTransform(), validate_args=False
130+
)
131+
new._validate_args = self._validate_args
132+
return new

src/tensorcontainer/distributions/truncated_normal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(
1414
high: Tensor,
1515
eps: float = 1e-6,
1616
validate_args=None,
17-
):
17+
) -> None:
1818
super().__init__(loc, scale, validate_args)
1919
self.low = low
2020
self.high = high

src/tensorcontainer/tensor_annotated.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from typing import Any, Iterable, TypeVar, Union, get_args
44

5-
import torch
65
from torch import Tensor
76
from torch.utils import _pytree as pytree
87
from typing_extensions import Self
98

109
from tensorcontainer.tensor_container import TensorContainer
10+
from tensorcontainer.types import DeviceLike, ShapeLike
1111
from tensorcontainer.utils import PytreeRegistered
1212

1313
TDCompatible = Union[Tensor, TensorContainer]
@@ -18,12 +18,8 @@
1818

1919

2020
class TensorAnnotated(TensorContainer, PytreeRegistered):
21-
def __init__(
22-
self,
23-
shape: torch.Size | list[int] | tuple[int, ...],
24-
device: str | torch.device | int | None,
25-
):
26-
super().__init__(shape, device, True)
21+
def __init__(self, shape: ShapeLike, device: DeviceLike | None):
22+
super().__init__(shape, device)
2723

2824
@classmethod
2925
def _get_annotations(cls, base_cls):

src/tensorcontainer/tensor_container.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,16 @@
55
import threading
66
from abc import abstractmethod
77
from contextlib import contextmanager
8-
from typing import (
9-
Any,
10-
Callable,
11-
Iterable,
12-
List,
13-
Optional,
14-
Tuple,
15-
Type,
16-
Union,
17-
)
8+
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
189

1910
import torch
2011

21-
# Use the official PyTree utility from torch
2212
import torch.utils._pytree as pytree
2313
from torch import Tensor
24-
from torch._prims_common import DeviceLikeType, ShapeType
2514
from torch.utils._pytree import Context, KeyEntry, PyTree
2615
from typing_extensions import Self, TypeAlias
2716

17+
from tensorcontainer.types import DeviceLike, ShapeLike
2818
from tensorcontainer.utils import resolve_device
2919

3020
HANDLED_FUNCTIONS = {}
@@ -229,25 +219,23 @@ class MyContainer(TensorContainer, PytreeRegistered):
229219
>>> first_batch = container[0] # Shape becomes (3,), events preserved
230220
"""
231221

232-
shape: ShapeType
222+
shape: torch.Size
233223
device: Optional[torch.device]
234224

235225
# Thread-local storage for unsafe construction flag
236226
_validation_disabled = threading.local()
237227

238228
def __init__(
239229
self,
240-
shape: ShapeType,
241-
device: Optional[DeviceLikeType],
242-
validate_args: bool = True,
230+
shape: ShapeLike,
231+
device: DeviceLike | None,
243232
):
244233
super().__init__()
245234

246-
self.shape = shape
247-
self.device = None if device is None else torch.device(resolve_device(device))
235+
self.shape = torch.Size(shape)
236+
self.device = None if device is None else resolve_device(device)
248237

249-
if validate_args:
250-
self._validate()
238+
self._validate()
251239

252240
@classmethod
253241
@contextmanager
@@ -378,7 +366,7 @@ def get_number_of_consuming_dims(self, item) -> int:
378366

379367
return 1
380368

381-
def transform_ellipsis_index(self, shape: tuple[int, ...], idx: tuple) -> tuple:
369+
def transform_ellipsis_index(self, shape: torch.Size, idx: tuple) -> tuple:
382370
"""
383371
Transforms an indexing tuple with an ellipsis into an equivalent one without it.
384372
...
@@ -465,7 +453,7 @@ def _format_item(key, value):
465453
# Assemble the final, properly formatted representation string
466454
return (
467455
f"{self.__class__.__name__}(\n"
468-
f"{indent}shape={str(self.shape)},\n"
456+
f"{indent}shape={tuple(self.shape)},\n"
469457
f"{indent}device={self.device},\n"
470458
f"{indent}items=\n{textwrap.indent(indented_items, indent)}\n{indent}\n"
471459
f")"
@@ -549,7 +537,7 @@ def __setitem__(self: Self, index: Any, value: Self) -> None:
549537
v[processed_index] = k.get(value)
550538
except Exception as e:
551539
raise type(e)(
552-
f"Issue with key {str(k)} and index {processed_index} for value of shape {v.shape} and type {type(v)} and assignment of shape {value.shape}"
540+
f"Issue with key {str(k)} and index {processed_index} for value of shape {v.shape} and type {type(v)} and assignment of shape {tuple(value.shape)}"
553541
) from e
554542

555543
def view(self: Self, *shape: int) -> Self:
@@ -753,7 +741,7 @@ def unsqueeze(self: Self, dim: int) -> Self:
753741

754742
def size(self) -> torch.Size:
755743
"""Returns the size of the batch dimensions."""
756-
return torch.Size(self.shape)
744+
return self.shape
757745

758746
def dim(self) -> int:
759747
"""Returns the number of batch dimensions."""

src/tensorcontainer/tensor_dataclass.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
import torch
99
from torch import Tensor
10+
from tensorcontainer.types import DeviceLike, ShapeLike
1011
from typing_extensions import dataclass_transform
1112

1213
from tensorcontainer.tensor_annotated import TensorAnnotated
13-
from tensorcontainer.tensor_container import ShapeType, TensorContainer
14+
from tensorcontainer.tensor_container import TensorContainer
1415

1516
TDCompatible = Union[Tensor, TensorContainer]
1617
DATACLASS_ARGS = {"init", "repr", "eq", "order", "unsafe_hash", "frozen", "slots"}
@@ -184,8 +185,8 @@ class FinalData(ExtendedData):
184185
# can enable static analyzers to provide type hints in IDEs. Both are programmatically
185186
# added in __init_subclass__ so removing the following two lines will only remove the
186187
# type hints, but the class will stay functional.
187-
shape: ShapeType
188-
device: Optional[torch.device]
188+
shape: ShapeLike
189+
device: DeviceLike
189190

190191
def __init_subclass__(cls, **kwargs):
191192
"""Automatically convert subclasses into dataclasses with proper field inheritance.
@@ -215,7 +216,7 @@ def __init_subclass__(cls, **kwargs):
215216
annotations = cls._get_annotations(TensorDataClass)
216217

217218
cls.__annotations__ = {
218-
"shape": torch.Size,
219+
"shape": ShapeLike,
219220
"device": Optional[torch.device],
220221
**annotations,
221222
}

0 commit comments

Comments
 (0)