
Commit a8eca59

Merge pull request #3 from mctigger/add-tensor-distributions
TensorDistribution: Release version
2 parents 8526951 + 51c7eca commit a8eca59

File tree

143 files changed: +10254 -1864 lines changed


README.md

Lines changed: 53 additions & 0 deletions
@@ -12,6 +12,59 @@ Tensor Container provides efficient, type-safe tensor container implementations

The library includes tensor containers, probabilistic distributions, and batch/event dimension semantics for machine learning workflows.

## What is TensorContainer?

TensorContainer transforms how you work with structured tensor data in PyTorch by providing **tensor-like operations for entire data structures**. Instead of manually managing individual tensors across devices, batch dimensions, and nested hierarchies, TensorContainer lets you treat complex data as unified entities that behave just like regular tensors.

### 🚀 **Unified Operations Across Data Types**

Apply tensor operations like `view()`, `permute()`, `detach()`, and device transfers to entire data structures—no matter how complex:

```python
# Single operation transforms entire distribution
distribution = distribution.view(2, 3, 4).permute(1, 0, 2).detach()

# Works seamlessly across TensorDict, TensorDataClass, and TensorDistribution
data = data.to('cuda').reshape(batch_size, -1).clone()
```

### 🔄 **Drop-in Compatibility with PyTorch**

TensorContainer integrates seamlessly with existing PyTorch workflows:
- **torch.distributions compatibility**: TensorDistribution is API-compatible with `torch.distributions` while adding tensor-like operations
- **PyTree support**: All containers work with `torch.utils._pytree` operations and `torch.compile` (see the sketch after this list)
- **Zero learning curve**: If you know PyTorch tensors, you already know TensorContainer
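
As a rough illustration of the PyTree bullet above, here is a minimal sketch. The `MyBatch` class and its fields are hypothetical; it relies only on the `TensorAnnotated` base class documented in `docs/tensor_annotated.md` within this commit and on `torch.utils._pytree`, so treat it as a sketch under those assumptions rather than canonical API usage:

```python
import torch
from torch import Tensor
from torch.utils import _pytree as pytree
from tensorcontainer.tensor_annotated import TensorAnnotated

class MyBatch(TensorAnnotated):
    observations: Tensor
    actions: Tensor

    def __init__(self, observations: Tensor, actions: Tensor):
        self.observations = observations
        self.actions = actions
        # Assumption: the batch shape is the shared leading dimension
        super().__init__(shape=observations.shape[:1], device=observations.device)

batch = MyBatch(observations=torch.randn(8, 4), actions=torch.randn(8, 2))

# Generic PyTree utilities see the annotated tensors as leaves ...
leaves, spec = pytree.tree_flatten(batch)
print(len(leaves))  # 2 leaves: observations and actions

# ... and tree_map rebuilds the same container type around the transformed leaves.
half = pytree.tree_map(lambda t: t.to(torch.float16), batch)
print(half.observations.dtype)  # torch.float16
```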

### **Eliminates Boilerplate Code**

Compare the complexity difference:

**With torch.distributions** (manual parameter handling):
```python
# Requires type-specific parameter extraction and reconstruction
if isinstance(dist, Normal):
    detached = Normal(loc=dist.loc.detach(), scale=dist.scale.detach())
elif isinstance(dist, Categorical):
    detached = Categorical(logits=dist.logits.detach())
# ... more type checks needed
```

**With TensorDistribution** (unified interface):
```python
# Works for any distribution type
detached = dist.detach()
```

### 🏗️ **Structured Data Made Simple**

Handle complex, nested tensor structures with the same ease as single tensors:
- **Batch semantics**: Consistent shape handling across all nested tensors
- **Device management**: Move entire structures between CPU/GPU with single operations (see the sketch after this list)
- **Shape validation**: Automatic verification of tensor compatibility
- **Type safety**: Full IDE support with static typing and autocomplete
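
To make the device-management bullet concrete, here is a short, hypothetical sketch. The `RolloutStep` class and its fields are invented for illustration; it uses only the `TensorAnnotated` base class and the `to()` behaviour described in `docs/tensor_annotated.md`:

```python
import torch
from torch import Tensor
from tensorcontainer.tensor_annotated import TensorAnnotated

class RolloutStep(TensorAnnotated):
    reward: Tensor
    done: Tensor

    def __init__(self, reward: Tensor, done: Tensor):
        self.reward = reward
        self.done = done
        super().__init__(shape=reward.shape, device=reward.device)

step = RolloutStep(reward=torch.zeros(32), done=torch.zeros(32, dtype=torch.bool))

# A single call moves every nested tensor; no per-field bookkeeping required.
if torch.cuda.is_available():
    step = step.to("cuda")
print(step.reward.device, step.done.device)  # both fields share the same device
```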

TensorContainer doesn't just store your data—it makes working with structured tensors as intuitive as working with individual tensors, while maintaining full compatibility with the PyTorch ecosystem you already know.

## Table of Contents

- [Installation](#installation)

docs/compatibility.md

Lines changed: 1 addition & 0 deletions
@@ -73,3 +73,4 @@ else:
    pass
```

Version-specific features should be avoided whenever possible. Instead, use one of the above solutions.

docs/tensor_annotated.md

Lines changed: 251 additions & 0 deletions
@@ -0,0 +1,251 @@

# TensorAnnotated Usage Guide

`TensorAnnotated` is a powerful base class designed to facilitate the creation of custom data structures that seamlessly integrate with PyTorch's PyTree mechanism. By subclassing `TensorAnnotated` and using type annotations for your attributes, you can define complex objects that behave like native PyTorch tensors in operations such as `copy()`, `to()`, `cuda()`, and more.

## What `TensorAnnotated` Offers

The core purpose of `TensorAnnotated` is to enable automatic PyTree flattening and unflattening for custom Python classes. This means:

* **Automatic Tensor Handling:** Any attribute type-annotated as a `torch.Tensor` or another `TensorContainer` (like `TensorDict` or another `TensorAnnotated` instance) will be automatically included in PyTree operations. This allows for easy movement of data between devices, cloning, and other tensor-centric manipulations.
* **Structured Data with PyTorch Integration:** You can define rich, domain-specific data structures (e.g., a `RobotState` class with `position: Tensor`, `velocity: Tensor`, `joint_angles: Tensor`) that still benefit from PyTorch's ecosystem.
* **Metadata Preservation:** Attributes that are type-annotated but are *not* tensors (e.g., integers, strings, lists) are treated as metadata and are preserved during PyTree operations, ensuring your object's non-tensor state is maintained.

## How to Use `TensorAnnotated`

To use `TensorAnnotated`, you need to subclass it and define your tensor attributes using type annotations.

### 1. Subclass `TensorAnnotated`

Begin by inheriting from `TensorAnnotated`.

```python
from torch import Tensor
from tensorcontainer.tensor_annotated import TensorAnnotated

class MyCustomData(TensorAnnotated):
    # ... define attributes and __init__
    pass
```

### 2. Define Annotated Attributes

Declare your attributes with type hints. For attributes you want to be part of PyTree operations, use `torch.Tensor` or `TensorContainer` types. For metadata, use any other Python type.

```python
import torch
from torch import Tensor
from tensorcontainer.tensor_annotated import TensorAnnotated

class MyCustomData(TensorAnnotated):
    my_tensor: Tensor
    my_other_tensor: Tensor
    my_metadata: str
    my_number: int

    def __init__(self, my_tensor: Tensor, my_other_tensor: Tensor, my_metadata: str, my_number: int):
        self.my_tensor = my_tensor
        self.my_other_tensor = my_other_tensor
        self.my_metadata = my_metadata
        self.my_number = my_number
        # IMPORTANT: Call super().__init__
        super().__init__(shape=my_tensor.shape, device=my_tensor.device)

# Example Instantiation
data_instance = MyCustomData(
    my_tensor=torch.randn(3, 4),
    my_other_tensor=torch.ones(3, 4),
    my_metadata="example",
    my_number=123
)

print(data_instance.my_tensor)
print(data_instance.my_metadata)
```

### 3. Call `super().__init__`

It is **crucial** to call `super().__init__(shape, device)` in your subclass's `__init__` method. This initializes the underlying `TensorContainer` and sets up the necessary `shape` and `device` properties for your `TensorAnnotated` instance. The `shape` and `device` should typically be derived from one of your primary tensor attributes.

```python
class MyCustomData(TensorAnnotated):
    my_tensor: Tensor

    def __init__(self, my_tensor: Tensor):
        self.my_tensor = my_tensor
        # Correct way to call super().__init__
        super().__init__(shape=my_tensor.shape, device=my_tensor.device)
```

### Example with PyTree Operations

Once instantiated, your `TensorAnnotated` object will behave like a PyTree, allowing operations like `copy()`, `to()`, `cuda()`, etc.

```python
import torch
from torch import Tensor
from tensorcontainer.tensor_annotated import TensorAnnotated

class MyModelOutput(TensorAnnotated):
    logits: Tensor
    hidden_state: Tensor
    model_name: str

    def __init__(self, logits: Tensor, hidden_state: Tensor, model_name: str):
        self.logits = logits
        self.hidden_state = hidden_state
        self.model_name = model_name
        super().__init__(shape=logits.shape, device=logits.device)

# Create an instance
output = MyModelOutput(
    logits=torch.randn(10, 5),
    hidden_state=torch.randn(10, 128),
    model_name="Transformer"
)

print(f"Original device: {output.logits.device}")
print(f"Original model name: {output.model_name}")

# Move to CPU (if on GPU) or GPU (if on CPU)
new_device = "cpu" if output.logits.is_cuda else "cuda"
output_on_new_device = output.to(new_device)

print(f"New device: {output_on_new_device.logits.device}")
print(f"New model name: {output_on_new_device.model_name}")  # Metadata is preserved

# Create a copy
output_copy = output.copy()
print(f"Copy logits are same object? {output_copy.logits is output.logits}")  # False, it's a deep copy
print(f"Copy model name: {output_copy.model_name}")
```

## Caveats and Limitations

Understanding these points is crucial for effectively using `TensorAnnotated` and avoiding unexpected behavior:

### 1. Only Annotated Tensors are PyTree Leaves

`TensorAnnotated`'s PyTree integration (flattening and unflattening) *only* considers attributes that are explicitly type-annotated as `torch.Tensor` or `TensorContainer`.
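
A minimal sketch of what this means in practice, assuming the automatic PyTree registration described above; `MyDataWithExtras` and its fields are hypothetical names chosen for illustration:

```python
import torch
from torch import Tensor
from torch.utils import _pytree as pytree
from tensorcontainer.tensor_annotated import TensorAnnotated

class MyDataWithExtras(TensorAnnotated):
    weights: Tensor  # annotated tensor -> PyTree leaf
    label: str       # annotated, but not a tensor -> treated as metadata

    def __init__(self, weights: Tensor, label: str):
        self.weights = weights
        self.label = label
        super().__init__(shape=weights.shape, device=weights.device)

data = MyDataWithExtras(weights=torch.randn(3), label="calibration")

# Flattening exposes only the annotated tensor attributes as leaves;
# the string is carried along in the spec as metadata.
leaves, spec = pytree.tree_flatten(data)
print(leaves)      # just the `weights` tensor
print(type(spec))  # the structure needed to rebuild MyDataWithExtras
```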

### 2. Attributes from Non-`TensorAnnotated` Parents are Ignored

If your class inherits from a parent class that does *not* subclass `TensorAnnotated`, any attributes defined in that non-`TensorAnnotated` parent will *not* be included in the PyTree operations. They will effectively be lost if you perform operations like `copy()`, `to()`, or `cuda()` on your `TensorAnnotated` instance.

**Example:**

```python
import torch
from torch import Tensor
from tensorcontainer.tensor_annotated import TensorAnnotated

class NonTensorAnnotatedParent:
    def __init__(self, non_pytree_attr: str):
        self.non_pytree_attr = non_pytree_attr

class MyCombinedData(TensorAnnotated, NonTensorAnnotatedParent):
    my_tensor: Tensor

    def __init__(self, my_tensor: Tensor, non_pytree_attr: str):
        self.my_tensor = my_tensor
        NonTensorAnnotatedParent.__init__(self, non_pytree_attr)  # Call parent's init
        super().__init__(shape=my_tensor.shape, device=my_tensor.device)

data = MyCombinedData(my_tensor=torch.randn(2), non_pytree_attr="I will be lost")
print(f"Original non_pytree_attr: {data.non_pytree_attr}")

copied_data = data.copy()

# This will raise an AttributeError because non_pytree_attr was not part of the PyTree
try:
    print(f"Copied non_pytree_attr: {copied_data.non_pytree_attr}")
except AttributeError as e:
    print(f"Error accessing copied non_pytree_attr: {e}")
```

### 3. Non-Annotated Attributes are Ignored

Any attribute assigned to `self` within your subclass's `__init__` or other methods that is *not* explicitly type-annotated will also be ignored by the PyTree mechanism. This means they will not be preserved across `copy()`, `to()`, etc.

```python
import torch
from torch import Tensor
from tensorcontainer.tensor_annotated import TensorAnnotated

class MyDataWithNonAnnotated(TensorAnnotated):
    annotated_tensor: Tensor
    # non_annotated_value: int <-- Missing annotation

    def __init__(self, annotated_tensor: Tensor, non_annotated_value: int):
        self.annotated_tensor = annotated_tensor
        self.non_annotated_value = non_annotated_value  # This attribute is not annotated
        super().__init__(shape=annotated_tensor.shape, device=annotated_tensor.device)

data = MyDataWithNonAnnotated(annotated_tensor=torch.randn(2), non_annotated_value=10)
print(f"Original non_annotated_value: {data.non_annotated_value}")

copied_data = data.copy()

# This will raise an AttributeError
try:
    print(f"Copied non_annotated_value: {copied_data.non_annotated_value}")
except AttributeError as e:
    print(f"Error accessing copied non_annotated_value: {e}")
```

### 4. Importance of Calling `super().__init__`

Failing to call `super().__init__(shape, device)` will result in an improperly initialized `TensorAnnotated` instance. Essential properties like `shape` and `device` will not be set, and PyTree operations will likely fail or produce incorrect results.
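
For illustration, a hedged sketch of this failure mode; the exact point and type of failure depend on the library internals, so the example only demonstrates that the uninitialized instance cannot be relied on:

```python
import torch
from torch import Tensor
from tensorcontainer.tensor_annotated import TensorAnnotated

class MissingSuperInit(TensorAnnotated):
    my_tensor: Tensor

    def __init__(self, my_tensor: Tensor):
        self.my_tensor = my_tensor
        # super().__init__(shape=..., device=...) is deliberately omitted here

# Depending on the library internals, the failure can surface during construction,
# when `shape`/`device` are first accessed, or as incorrect PyTree results.
try:
    broken = MissingSuperInit(torch.randn(2))
    print(broken.shape)  # never set by super().__init__
except Exception as e:
    print(f"Improperly initialized instance: {type(e).__name__}: {e}")
```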

### 5. Reserved Attributes: `shape` and `device`

The attributes `shape` and `device` are internally managed by `TensorAnnotated` (inherited from `TensorContainer`). You **cannot** define these as annotated attributes in your subclasses. Attempting to do so will result in a `TypeError`.

```python
# This will raise a TypeError
# class InvalidData(TensorAnnotated):
#     shape: torch.Size  # ERROR: Cannot define reserved fields
#     my_tensor: Tensor
#
#     def __init__(self, my_tensor: Tensor):
#         self.my_tensor = my_tensor
#         super().__init__(shape=my_tensor.shape, device=my_tensor.device)
```

### 6. Inheritance with Multiple Parents

When using multiple inheritance, `TensorAnnotated` correctly collects annotations from all `TensorAnnotated` parent classes in the Method Resolution Order (MRO). However, ensure that your `__init__` method correctly calls the `__init__` of all relevant parent classes, especially the `TensorAnnotated` ones, passing `shape` and `device` appropriately.

```python
import torch
from torch import Tensor
from tensorcontainer.tensor_annotated import TensorAnnotated

class ParentA(TensorAnnotated):
    tensor_a: Tensor

    def __init__(self, tensor_a: Tensor, **kwargs):
        self.tensor_a = tensor_a
        super().__init__(**kwargs)  # Pass kwargs to allow shape/device from child

class ParentB(TensorAnnotated):
    tensor_b: Tensor

    def __init__(self, tensor_b: Tensor, **kwargs):
        self.tensor_b = tensor_b
        super().__init__(**kwargs)  # Pass kwargs to allow shape/device from child

class Child(ParentA, ParentB):
    tensor_c: Tensor

    def __init__(self, tensor_a: Tensor, tensor_b: Tensor, tensor_c: Tensor):
        self.tensor_c = tensor_c
        # Call parents' inits, ensuring shape and device are passed to the ultimate TensorAnnotated init
        super().__init__(
            tensor_a=tensor_a,
            tensor_b=tensor_b,
            shape=tensor_c.shape,  # Use one of the tensors for shape/device
            device=tensor_c.device
        )

data = Child(torch.randn(5), torch.randn(5), torch.randn(5))
copied_data = data.copy()

# All annotated tensors from every TensorAnnotated parent survive the copy.
assert torch.equal(copied_data.tensor_a, data.tensor_a)
assert torch.equal(copied_data.tensor_b, data.tensor_b)
assert torch.equal(copied_data.tensor_c, data.tensor_c)
```
