Commit c21b636

Merge pull request #18 from mctigger/documentation
Documentation
2 parents a8d546e + 637c29f commit c21b636

4 files changed: +80 -236 lines changed

README.md

Lines changed: 80 additions & 219 deletions
@@ -3,78 +3,64 @@
33
*Tensor containers for PyTorch with PyTree compatibility and torch.compile optimization*
44

55
[![Docs](https://img.shields.io/static/v1?logo=github&style=flat&color=pink&label=docs&message=tensorcontainer)](tree/main/docs)
6-
[![Python 3.9, 3.10, 3.11, 3.12](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.11%20%7C%203.12-blue)](https://www.python.org/downloads/)
6+
[![Documentation](https://img.shields.io/badge/docs-local-blue)](./docs/user_guide/README.md)
7+
[![Python 3.9, 3.10, 3.11, 3.12](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)](https://www.python.org/downloads/)
78
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
89
[![PyTorch](https://img.shields.io/badge/PyTorch-2.6+-blue.svg)](https://pytorch.org/)
910
<a href="https://pypi.org/project/tensorcontainer"><img src="https://img.shields.io/pypi/v/tensorcontainer" alt="pypi version"></a>
1011

1112

12-
> **⚠️ Academic Research Project**: This project exists solely for academic purposes to explore and learn PyTorch internals. For production use, please use the official, well-maintained [**torch/tensordict**](https://github.com/pytorch/tensordict) library.
13-
1413
Tensor Container provides efficient, type-safe tensor container implementations for PyTorch workflows. It includes PyTree integration and torch.compile optimization for batched tensor operations.
1514

16-
The library includes tensor containers, probabilistic distributions, and batch/event dimension semantics for machine learning workflows.
15+
The library includes tensor containers (dict, dataclass) and distributions (torch.distributions equivalent).
1716

1817
## What is TensorContainer?
1918

20-
TensorContainer transforms how you work with structured tensor data in PyTorch by providing **tensor-like operations for entire data structures**. Instead of manually managing individual tensors across devices, batch dimensions, and nested hierarchies, TensorContainer lets you treat complex data as unified entities that behave just like regular tensors.
21-
22-
### 🚀 **Unified Operations Across Data Types**
23-
24-
Apply tensor operations like `view()`, `permute()`, `detach()`, and device transfers to entire data structures—no matter how complex:
25-
26-
```python
27-
# Single operation transforms entire distribution
28-
distribution = distribution.view(2, 3, 4).permute(1, 0, 2).detach()
29-
30-
# Works seamlessly across TensorDict, TensorDataClass, and TensorDistribution
31-
data = data.to('cuda').reshape(batch_size, -1).clone()
32-
```
33-
34-
### 🔄 **Drop-in Compatibility with PyTorch**
19+
TensorContainer transforms how you work with structured tensor data in PyTorch by providing **tensor-like operations for entire data structures**. Instead of manually managing individual tensors, TensorContainer lets you treat complex data as unified entities that behave just like regular tensors.
3520

36-
TensorContainer integrates seamlessly with existing PyTorch workflows:
37-
- **torch.distributions compatibility**: TensorDistribution is API-compatible with `torch.distributions` while adding tensor-like operations
38-
- **PyTree support**: All containers work with `torch.utils._pytree` operations and `torch.compile`
39-
- **Zero learning curve**: If you know PyTorch tensors, you already know TensorContainer
21+
### **Core Benefits**
4022

41-
### **Eliminates Boilerplate Code**
23+
- **Unified Operations**: Apply tensor operations like `view()`, `permute()`, `detach()`, and device transfers to entire data structures
24+
- **Drop-in Compatibility**: Seamless integration with existing PyTorch workflows and `torch.compile`
25+
- **Zero Boilerplate**: Eliminate manual parameter handling and type-specific operations
26+
- **Type Safety**: Full IDE support with static typing and autocomplete
4227

43-
Compare the complexity difference:
44-
45-
**With torch.distributions** (manual parameter handling):
4628
```python
47-
# Requires type-specific parameter extraction and reconstruction
48-
if isinstance(dist, Normal):
49-
detached = Normal(loc=dist.loc.detach(), scale=dist.scale.detach())
50-
elif isinstance(dist, Categorical):
51-
detached = Categorical(logits=dist.logits.detach())
52-
# ... more type checks needed
53-
```
29+
data = TensorDict(
30+
{"a": torch.rand(24), "b": torch.rand(24)},
31+
shape=(24,),
32+
device="cpu"
33+
)
5434

55-
**With TensorDistribution** (unified interface):
56-
```python
57-
# Works for any distribution type
58-
detached = dist.detach()
35+
# Single operation transforms entire structure
36+
data = data.view(2, 3, 4).permute(1, 0, 2).to('cuda').detach()
5937
```
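
For comparison, here is a minimal sketch of the per-tensor bookkeeping that a single container-level call replaces. It uses only plain PyTorch and a regular `dict`. The helper name `transform_all` is hypothetical and not part of TensorContainer, and the device transfer is left out so the snippet runs on CPU-only machines.

```python
import torch

# Plain-dict stand-in for the container in the example above (assumption:
# two 1-D tensors with 24 elements each).
tensors = {"a": torch.rand(24), "b": torch.rand(24)}


def transform_all(tensors):
    """Apply the same view/permute/detach chain to every tensor by hand."""
    return {
        key: value.view(2, 3, 4).permute(1, 0, 2).detach()
        for key, value in tensors.items()
    }


manual = transform_all(tensors)

# Collect the resulting shapes; every tensor went through the same chain.
shapes = {key: tuple(value.shape) for key, value in manual.items()}
```

Each tensor ends up with shape `(3, 2, 4)`. The dictionary comprehension has to be repeated, or wrapped in a helper like this one, for every operation that a container applies in one call.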
6038

61-
### 🏗️ **Structured Data Made Simple**
39+
### **Key Features**
6240

63-
Handle complex, nested tensor structures with the same ease as single tensors:
64-
- **Batch semantics**: Consistent shape handling across all nested tensors
65-
- **Device management**: Move entire structures between CPU/GPU with single operations
66-
- **Shape validation**: Automatic verification of tensor compatibility
67-
- **Type safety**: Full IDE support with static typing and autocomplete
41+
- **⚡ JIT Compilation**: Designed for `torch.compile` with `fullgraph=True`, minimizing graph breaks and maximizing performance
42+
- **📐 Batch/Event Semantics**: Clear distinction between batch dimensions (consistent across tensors) and event dimensions (tensor-specific)
43+
- **🔄 Device Management**: Move entire structures between CPU/GPU with single operations and flexible device compatibility
44+
- **🔒 Type Safety**: Full IDE support with static typing and autocomplete
45+
- **🏗️ Multiple Container Types**: Three specialized containers for different use cases:
46+
- `TensorDict` for dynamic, dictionary-style data collections
47+
- `TensorDataClass` for type-safe, dataclass-based structures
48+
- `TensorDistribution` for probabilistic modeling with 40+ probability distributions
49+
- **🔧 Advanced Operations**: Full PyTorch tensor operations support including `view()`, `permute()`, `stack()`, `cat()`, and more
50+
- **🎯 Advanced Indexing**: Complete PyTorch indexing semantics with boolean masks, tensor indices, and ellipsis support
51+
- **📊 Shape Validation**: Automatic verification of tensor compatibility with detailed error messages
52+
- **🌳 Nested Structure Support**: Build nested structures that combine different TensorContainer types
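
The batch/event distinction listed above can be illustrated with shapes alone. The following is a rough sketch, not the library's validation code: `event_shape` is a hypothetical helper that splits a tensor shape into the shared leading batch dimensions and the tensor-specific trailing event dimensions.

```python
from typing import Tuple


def event_shape(
    tensor_shape: Tuple[int, ...], batch_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    """Return the trailing event dims of a tensor shape.

    Raises ValueError when the leading dims do not match the batch shape.
    """
    n = len(batch_shape)
    if tuple(tensor_shape[:n]) != tuple(batch_shape):
        raise ValueError(
            f"shape {tuple(tensor_shape)} does not start with "
            f"batch shape {tuple(batch_shape)}"
        )
    return tuple(tensor_shape[n:])


# Both tensors share the batch shape (32,) but differ in their event dims.
observations_event = event_shape((32, 128), (32,))
rewards_event = event_shape((32, 1), (32,))
```

Here `observations_event` is `(128,)` and `rewards_event` is `(1,)`. A shape such as `(16, 4)` would be rejected for a batch shape of `(32,)`.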
6853

69-
TensorContainer doesn't just store your data—it makes working with structured tensors as intuitive as working with individual tensors, while maintaining full compatibility with the PyTorch ecosystem you already know.
7054

7155
## Table of Contents
7256

57+
- [What is TensorContainer?](#what-is-tensorcontainer)
7358
- [Installation](#installation)
7459
- [Quick Start](#quick-start)
7560
- [Features](#features)
7661
- [API Overview](#api-overview)
7762
- [torch.compile Compatibility](#torchcompile-compatibility)
63+
- [Examples](#examples)
7864
- [Contributing](#contributing)
7965
- [Documentation](#documentation)
8066
- [License](#license)
@@ -83,227 +69,102 @@ TensorContainer doesn't just store your data—it makes working with structured
8369

8470
## Installation
8571

86-
### From Source (Development)
72+
### Using pip
8773

8874
```bash
89-
# Clone the repository
90-
git clone https://github.com/mctigger/tensor-container.git
91-
cd tensor-container
92-
93-
# Install in development mode
94-
pip install -e .
95-
96-
# Install with development dependencies
97-
pip install -e .[dev]
75+
pip install tensorcontainer
9876
```
9977

10078
### Requirements
10179

10280
- Python 3.9+
103-
- PyTorch 2.0+
81+
- PyTorch 2.6+
10482

10583
## Quick Start
10684

107-
### TensorDict: Dictionary-Style Containers
85+
TensorContainer transforms how you work with structured tensor data. Instead of managing individual tensors, you can treat entire data structures as unified entities that behave like regular tensors.
86+
87+
```python
88+
# Single operation transforms entire structure
89+
data = data.view(2, 3, 4).permute(1, 0, 2).to('cuda').detach()
90+
```
91+
92+
### 1. TensorDict: Dynamic Data Collections
93+
94+
Perfect for reinforcement learning data and dynamic collections:
10895

10996
```python
11097
import torch
11198
from tensorcontainer import TensorDict
11299

113-
# Create a TensorDict with batch semantics
100+
# Create a container for RL training data
114101
data = TensorDict({
115102
'observations': torch.randn(32, 128),
116103
'actions': torch.randn(32, 4),
117104
'rewards': torch.randn(32, 1)
118-
}, shape=(32,), device='cpu')
105+
}, shape=(32,))
119106

120-
# Dictionary-like access
107+
# Dictionary-like access with tensor operations
121108
obs = data['observations']
122-
data['new_field'] = torch.zeros(32, 10)
109+
data['advantages'] = torch.randn(32, 1) # Add new fields dynamically
123110

124111
# Batch operations work seamlessly
125-
stacked_data = torch.stack([data, data]) # Shape: (2, 32)
112+
batch = torch.stack([data, data]) # Shape: (2, 32)
126113
```
127114

128-
### TensorDataClass: Type-Safe Containers
115+
### 2. TensorDataClass: Type-Safe Structures
116+
117+
Ideal for model inputs and structured data with statically typed fields:
129118

130119
```python
131120
import torch
132121
from tensorcontainer import TensorDataClass
133122

134-
class RLData(TensorDataClass):
135-
observations: torch.Tensor
136-
actions: torch.Tensor
137-
rewards: torch.Tensor
123+
class ModelInput(TensorDataClass):
124+
features: torch.Tensor
125+
labels: torch.Tensor
138126

139127
# Create with full type safety and IDE support
140-
data = RLData(
141-
observations=torch.randn(32, 128),
142-
actions=torch.randn(32, 4),
143-
rewards=torch.randn(32, 1),
144-
shape=(32,),
145-
device='cpu'
128+
batch = ModelInput(
129+
features=torch.randn(32, 64, 784),
130+
labels=torch.randint(0, 10, (32, 64)),
131+
shape=(32, 64)
146132
)
147133

148-
# Type-safe field access with autocomplete
149-
obs = data.observations
150-
data.actions = torch.randn(32, 8) # Type-checked assignment
151-
```
152-
153-
### TensorDistribution: Probabilistic Containers
134+
# Unified operations on entire structure - reshape all tensors at once
135+
batch = batch.view(2048)
154136

155-
```python
156-
import torch
157-
from tensorcontainer import TensorDistribution
158-
159-
# Built-in distribution types
160-
from tensorcontainer.tensor_distribution import (
161-
TensorNormal, TensorBernoulli, TensorCategorical,
162-
TensorTruncatedNormal, TensorTanhNormal
163-
)
164-
165-
# Create probabilistic tensor containers
166-
normal_dist = TensorNormal(
167-
loc=torch.zeros(32, 4),
168-
scale=torch.ones(32, 4),
169-
shape=(32,),
170-
device='cpu'
171-
)
172-
173-
# Sample and compute probabilities
174-
samples = normal_dist.sample() # Shape: (32, 4)
175-
log_probs = normal_dist.log_prob(samples)
176-
entropy = normal_dist.entropy()
177-
178-
# Categorical distributions for discrete actions
179-
categorical = TensorCategorical(
180-
logits=torch.randn(32, 6), # 6 possible actions
181-
shape=(32,),
182-
device='cpu'
183-
)
137+
# Type-safe access with autocomplete works on reshaped data too
138+
loss = torch.nn.functional.cross_entropy(batch.features, batch.labels)
184139
```
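
To make the effect of `batch.view(2048)` concrete, the sketch below reproduces the same reshape on plain tensors: only the leading batch dimensions change, while each tensor keeps its own trailing event dimensions. `reshape_batch` is a hypothetical helper written for illustration and is not TensorContainer's implementation.

```python
import torch


def reshape_batch(tensor, batch_ndim, new_batch_shape):
    """Reshape the leading batch dims and leave trailing event dims as-is."""
    event_dims = tensor.shape[batch_ndim:]
    return tensor.reshape(*new_batch_shape, *event_dims)


# Same shapes as the ModelInput example: batch shape (32, 64).
features = torch.randn(32, 64, 784)
labels = torch.randint(0, 10, (32, 64))

# Flatten the two batch dims into one of size 32 * 64 = 2048.
flat_features = reshape_batch(features, batch_ndim=2, new_batch_shape=(2048,))
flat_labels = reshape_batch(labels, batch_ndim=2, new_batch_shape=(2048,))
```

After the call, `flat_features` has shape `(2048, 784)` and `flat_labels` has shape `(2048,)`.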
185140

186-
### PyTree Operations
141+
### 3. TensorDistribution: Probabilistic Modeling
187142

188-
```python
189-
# All containers work seamlessly with PyTree operations
190-
import torch.utils._pytree as pytree
191-
192-
# Transform all tensors in the container
193-
doubled_data = pytree.tree_map(lambda x: x * 2, data)
194-
195-
# Combine multiple containers
196-
combined = pytree.tree_map(lambda x, y: x + y, data1, data2)
197-
```
198-
199-
## Features
200-
201-
- **torch.compile Optimized**: Compatible with PyTorch's JIT compiler
202-
- **PyTree Support**: Integration with `torch.utils._pytree` for tree operations
203-
- **Zero-Copy Operations**: Efficient tensor sharing and manipulation
204-
- **Type Safety**: Static typing support with IDE autocomplete and type checking
205-
- **Batch Semantics**: Consistent batch/event dimension handling
206-
- **Shape Validation**: Automatic validation of tensor shapes and device consistency
207-
- **Multiple Container Types**: Different container types for different use cases
208-
- **Probabilistic Support**: Distribution containers for probabilistic modeling
209-
- **Comprehensive Testing**: Extensive test suite with compile compatibility verification
210-
- **Memory Efficient**: Optimized memory usage with slots-based dataclasses
211-
212-
## API Overview
213-
214-
### Core Components
215-
216-
- **`TensorContainer`**: Base class providing core tensor manipulation operations with batch/event dimension semantics
217-
- **`TensorDict`**: Dictionary-like container for dynamic tensor collections with nested structure support
218-
- **`TensorDataClass`**: DataClass-based container for static, typed tensor structures
219-
- **`TensorDistribution`**: Distribution wrapper for probabilistic tensor operations
220-
221-
### Key Concepts
222-
223-
- **Batch Dimensions**: Leading dimensions defined by the `shape` parameter, consistent across all tensors
224-
- **Event Dimensions**: Trailing dimensions beyond batch shape, can vary per tensor
225-
- **PyTree Integration**: All containers are registered PyTree nodes for seamless tree operations
226-
- **Device Consistency**: Automatic validation ensures all tensors reside on compatible devices
227-
- **Unsafe Construction**: Context manager for performance-critical scenarios with validation bypass
228-
229-
## torch.compile Compatibility
230-
231-
Tensor Container is designed for `torch.compile` compatibility:
143+
Streamline probabilistic computations in reinforcement learning and generative models:
232144

233145
```python
234-
@torch.compile
235-
def process_batch(data: TensorDict) -> TensorDict:
236-
# PyTree operations compile efficiently
237-
return TensorContainer._tree_map(lambda x: torch.relu(x), data)
238-
239-
@torch.compile
240-
def sample_and_score(dist: TensorNormal, actions: torch.Tensor) -> torch.Tensor:
241-
# Distribution operations are compile-safe
242-
return dist.log_prob(actions)
243-
244-
# All operations compile efficiently with minimal graph breaks
245-
compiled_result = process_batch(tensor_dict)
246-
log_probs = sample_and_score(normal_dist, action_tensor)
247-
```
248-
249-
The testing framework includes compile compatibility verification to ensure operations work efficiently under JIT compilation, including:
250-
- Graph break detection and minimization
251-
- Recompilation tracking
252-
- Memory leak prevention
253-
- Performance benchmarking
254-
255-
## Contributing
256-
257-
Contributions are welcome! Tensor Container is a learning project for exploring PyTorch internals and tensor container implementations.
258-
259-
### Development Setup
260-
261-
```bash
262-
# Clone and install in development mode
263-
git clone https://github.com/mctigger/tensor-container.git
264-
cd tensor-container
265-
pip install -e .[dev]
266-
```
267-
268-
### Running Tests
269-
270-
```bash
271-
# Run all tests with coverage
272-
pytest --strict-markers --cov=src
146+
import torch
147+
from tensorcontainer.tensor_distribution import TensorNormal
273148

274-
# Run specific test modules
275-
pytest tests/tensor_dict/test_compile.py
276-
pytest tests/tensor_dataclass/
277-
pytest tests/tensor_distribution/
149+
normal = TensorNormal(
150+
loc=torch.zeros(100, 4),
151+
scale=torch.ones(100, 4)
152+
)
278153

279-
# Run compile-specific tests
280-
pytest tests/tensor_dict/test_graph_breaks.py
281-
pytest tests/tensor_dict/test_recompilations.py
154+
# With torch.distributions we need to extract the parameters, detach them
155+
# and create a new Normal distribution. With TensorDistribution we just call
156+
# .detach() on the distribution. We can also apply other tensor operations,
157+
# such as .view()!
158+
detached_normal = normal.detach()
282159
```
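
The manual route mentioned in the comment above looks roughly like this with plain `torch.distributions`. This is a sketch for comparison only. The `requires_grad=True` parameters are an assumption added here so that the effect of detaching is visible.

```python
import torch
from torch.distributions import Normal

# Parameters that track gradients (assumption made for illustration).
loc = torch.zeros(100, 4, requires_grad=True)
scale = torch.ones(100, 4, requires_grad=True)
normal = Normal(loc=loc, scale=scale)

# Manual detach: pull out each parameter, detach it, and rebuild the
# distribution. Every distribution type needs its own version of this step.
detached = Normal(loc=normal.loc.detach(), scale=normal.scale.detach())

# Reparameterized samples from the detached copy carry no gradient history.
sample = detached.rsample()
```

The rebuild step depends on knowing which parameters a given distribution class exposes, which is the type-specific handling that a single `.detach()` call avoids.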
283160

284-
### Development Guidelines
285-
286-
- All new features must maintain `torch.compile` compatibility
287-
- Comprehensive tests required, including compile compatibility verification
288-
- Follow existing code patterns and typing conventions
289-
- Distribution implementations must support KL divergence registration
290-
- Memory efficiency considerations for large-scale tensor operations
291-
- Unsafe construction patterns for performance-critical paths
292-
293-
### Contribution Process
294-
295-
1. Fork the repository
296-
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
297-
3. Make your changes with appropriate tests
298-
4. Ensure all tests pass and maintain coverage
299-
5. Submit a pull request with a clear description
300-
301161
## Documentation
302162

303-
The project includes documentation:
163+
The project includes comprehensive documentation:
304164

305-
- **[`docs/compatibility.md`](docs/compatibility.md)**: Python version compatibility guide and best practices
306-
- **[`docs/testing.md`](docs/testing.md)**: Testing philosophy, standards, and guidelines
165+
- **[`docs/user_guide/overview.md`](docs/user_guide/overview.md)**: Complete user guide with examples and best practices
166+
- **[`docs/developer_guide/compatibility.md`](docs/developer_guide/compatibility.md)**: Python version compatibility guide and best practices
167+
- **[`docs/developer_guide/testing.md`](docs/developer_guide/testing.md)**: Testing philosophy, standards, and guidelines
307168
- **Source Code Documentation**: Extensive docstrings and type annotations throughout the codebase
308169
- **Test Coverage**: 643+ tests covering all major functionality with 86% code coverage
309170

File renamed without changes.
File renamed without changes.
