Skip to content

Commit ddeb194

Browse files
author
Tim Joseph
committed
docs: update README and restructure documentation
- Updated main README.md with improved structure, removed redundant sections, and enhanced clarity - Added comprehensive user guide in docs/user_guide/README.md with detailed examples and concepts - Reorganized examples documentation with new overview and navigation - Removed outdated examples/tensor_dataclass/README.md
1 parent a6f7138 commit ddeb194

File tree

4 files changed

+79
-236
lines changed

4 files changed

+79
-236
lines changed

README.md

Lines changed: 79 additions & 219 deletions
Original file line numberDiff line numberDiff line change
@@ -3,78 +3,63 @@
33
*Tensor containers for PyTorch with PyTree compatibility and torch.compile optimization*
44

55
[![Docs](https://img.shields.io/static/v1?logo=github&style=flat&color=pink&label=docs&message=tensorcontainer)](tree/main/docs)
6-
[![Python 3.9, 3.10, 3.11, 3.12](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.11%20%7C%203.12-blue)](https://www.python.org/downloads/)
6+
[![Python 3.9, 3.10, 3.11, 3.12](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)](https://www.python.org/downloads/)
77
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
88
[![PyTorch](https://img.shields.io/badge/PyTorch-2.6+-blue.svg)](https://pytorch.org/)
99
<a href="https://pypi.org/project/tensorcontainer"><img src="https://img.shields.io/pypi/v/tensorcontainer" alt="pypi version"></a>
1010

1111

12-
> **⚠️ Academic Research Project**: This project exists solely for academic purposes to explore and learn PyTorch internals. For production use, please use the official, well-maintained [**torch/tensordict**](https://github.com/pytorch/tensordict) library.
13-
1412
Tensor Container provides efficient, type-safe tensor container implementations for PyTorch workflows. It includes PyTree integration and torch.compile optimization for batched tensor operations.
1513

16-
The library includes tensor containers, probabilistic distributions, and batch/event dimension semantics for machine learning workflows.
14+
The library includes tensor containers (dict, dataclass) and distributions (torch.distributions equivalent).
1715

1816
## What is TensorContainer?
1917

20-
TensorContainer transforms how you work with structured tensor data in PyTorch by providing **tensor-like operations for entire data structures**. Instead of manually managing individual tensors across devices, batch dimensions, and nested hierarchies, TensorContainer lets you treat complex data as unified entities that behave just like regular tensors.
21-
22-
### 🚀 **Unified Operations Across Data Types**
23-
24-
Apply tensor operations like `view()`, `permute()`, `detach()`, and device transfers to entire data structures—no matter how complex:
25-
26-
```python
27-
# Single operation transforms entire distribution
28-
distribution = distribution.view(2, 3, 4).permute(1, 0, 2).detach()
29-
30-
# Works seamlessly across TensorDict, TensorDataClass, and TensorDistribution
31-
data = data.to('cuda').reshape(batch_size, -1).clone()
32-
```
33-
34-
### 🔄 **Drop-in Compatibility with PyTorch**
18+
TensorContainer transforms how you work with structured tensor data in PyTorch by providing **tensor-like operations for entire data structures**. Instead of manually managing individual tensors, TensorContainer lets you treat complex data as unified entities that behave just like regular tensors.
3519

36-
TensorContainer integrates seamlessly with existing PyTorch workflows:
37-
- **torch.distributions compatibility**: TensorDistribution is API-compatible with `torch.distributions` while adding tensor-like operations
38-
- **PyTree support**: All containers work with `torch.utils._pytree` operations and `torch.compile`
39-
- **Zero learning curve**: If you know PyTorch tensors, you already know TensorContainer
20+
### **Core Benefits**
4021

41-
### **Eliminates Boilerplate Code**
22+
- **Unified Operations**: Apply tensor operations like `view()`, `permute()`, `detach()`, and device transfers to entire data structures
23+
- **Drop-in Compatibility**: Seamless integration with existing PyTorch workflows and `torch.compile`
24+
- **Zero Boilerplate**: Eliminate manual parameter handling and type-specific operations
25+
- **Type Safety**: Full IDE support with static typing and autocomplete
4226

43-
Compare the complexity difference:
44-
45-
**With torch.distributions** (manual parameter handling):
4627
```python
47-
# Requires type-specific parameter extraction and reconstruction
48-
if isinstance(dist, Normal):
49-
detached = Normal(loc=dist.loc.detach(), scale=dist.scale.detach())
50-
elif isinstance(dist, Categorical):
51-
detached = Categorical(logits=dist.logits.detach())
52-
# ... more type checks needed
53-
```
28+
data = TensorDict(
29+
{"a": torch.rand(24), "b": torch.rand(24)},
30+
shape=(24,),
31+
device="cpu"
32+
)
5433

55-
**With TensorDistribution** (unified interface):
56-
```python
57-
# Works for any distribution type
58-
detached = dist.detach()
34+
# Single operation transforms entire structure
35+
data = data.view(2, 3, 4).permute(1, 0, 2).to('cuda').detach()
5936
```
6037

61-
### 🏗️ **Structured Data Made Simple**
38+
### **Key Features**
6239

63-
Handle complex, nested tensor structures with the same ease as single tensors:
64-
- **Batch semantics**: Consistent shape handling across all nested tensors
65-
- **Device management**: Move entire structures between CPU/GPU with single operations
66-
- **Shape validation**: Automatic verification of tensor compatibility
67-
- **Type safety**: Full IDE support with static typing and autocomplete
40+
- **⚡ JIT Compilation**: Designed for `torch.compile` with `fullgraph=True`, minimizing graph breaks and maximizing performance
41+
- **📐 Batch/Event Semantics**: Clear distinction between batch dimensions (consistent across tensors) and event dimensions (tensor-specific)
42+
- **🔄 Device Management**: Move entire structures between CPU/GPU with single operations and flexible device compatibility
43+
- **🔒 Type Safety**: Full IDE support with static typing and autocomplete
44+
- **🏗️ Multiple Container Types**: Three specialized containers for different use cases:
45+
- `TensorDict` for dynamic, dictionary-style data collections
46+
- `TensorDataClass` for type-safe, dataclass-based structures
47+
- `TensorDistribution` for probabilistic modeling with 40+ probability distributions
48+
- **🔧 Advanced Operations**: Full PyTorch tensor operations support including `view()`, `permute()`, `stack()`, `cat()`, and more
49+
- **🎯 Advanced Indexing**: Complete PyTorch indexing semantics with boolean masks, tensor indices, and ellipsis support
50+
- **📊 Shape Validation**: Automatic verification of tensor compatibility with detailed error messages
51+
- **🌳 Nested Structure Support**: Create nested structure with different TensorContainers
6852

69-
TensorContainer doesn't just store your data—it makes working with structured tensors as intuitive as working with individual tensors, while maintaining full compatibility with the PyTorch ecosystem you already know.
7053

7154
## Table of Contents
7255

56+
- [What is TensorContainer?](#what-is-tensorcontainer)
7357
- [Installation](#installation)
7458
- [Quick Start](#quick-start)
7559
- [Features](#features)
7660
- [API Overview](#api-overview)
7761
- [torch.compile Compatibility](#torchcompile-compatibility)
62+
- [Examples](#examples)
7863
- [Contributing](#contributing)
7964
- [Documentation](#documentation)
8065
- [License](#license)
@@ -83,227 +68,102 @@ TensorContainer doesn't just store your data—it makes working with structured
8368

8469
## Installation
8570

86-
### From Source (Development)
71+
### Using pip
8772

8873
```bash
89-
# Clone the repository
90-
git clone https://github.com/mctigger/tensor-container.git
91-
cd tensor-container
92-
93-
# Install in development mode
94-
pip install -e .
95-
96-
# Install with development dependencies
97-
pip install -e .[dev]
74+
pip install tensorcontainer
9875
```
9976

10077
### Requirements
10178

10279
- Python 3.9+
103-
- PyTorch 2.0+
80+
- PyTorch 2.6+
10481

10582
## Quick Start
10683

107-
### TensorDict: Dictionary-Style Containers
84+
TensorContainer transforms how you work with structured tensor data. Instead of managing individual tensors, you can treat entire data structures as unified entities that behave like regular tensors.
85+
86+
```python
87+
# Single operation transforms entire structure
88+
data = data.view(2, 3, 4).permute(1, 0, 2).to('cuda').detach()
89+
```
90+
91+
### 1. TensorDict: Dynamic Data Collections
92+
93+
Perfect for reinforcement learning data and dynamic collections:
10894

10995
```python
11096
import torch
11197
from tensorcontainer import TensorDict
11298

113-
# Create a TensorDict with batch semantics
99+
# Create a container for RL training data
114100
data = TensorDict({
115101
'observations': torch.randn(32, 128),
116102
'actions': torch.randn(32, 4),
117103
'rewards': torch.randn(32, 1)
118-
}, shape=(32,), device='cpu')
104+
}, shape=(32,))
119105

120-
# Dictionary-like access
106+
# Dictionary-like access with tensor operations
121107
obs = data['observations']
122-
data['new_field'] = torch.zeros(32, 10)
108+
data['advantages'] = torch.randn(32, 1) # Add new fields dynamically
123109

124110
# Batch operations work seamlessly
125-
stacked_data = torch.stack([data, data]) # Shape: (2, 32)
111+
batch = torch.stack([data, data]) # Shape: (2, 32)
126112
```
127113

128-
### TensorDataClass: Type-Safe Containers
114+
### 2. TensorDataClass: Type-Safe Structures
115+
116+
Ideal for model inputs and structured data with compile-time safety:
129117

130118
```python
131119
import torch
132120
from tensorcontainer import TensorDataClass
133121

134-
class RLData(TensorDataClass):
135-
observations: torch.Tensor
136-
actions: torch.Tensor
137-
rewards: torch.Tensor
122+
class ModelInput(TensorDataClass):
123+
features: torch.Tensor
124+
labels: torch.Tensor
138125

139126
# Create with full type safety and IDE support
140-
data = RLData(
141-
observations=torch.randn(32, 128),
142-
actions=torch.randn(32, 4),
143-
rewards=torch.randn(32, 1),
144-
shape=(32,),
145-
device='cpu'
127+
batch = ModelInput(
128+
features=torch.randn(32, 64, 784),
129+
labels=torch.randint(0, 10, (32, 64)),
130+
shape=(32, 64)
146131
)
147132

148-
# Type-safe field access with autocomplete
149-
obs = data.observations
150-
data.actions = torch.randn(32, 8) # Type-checked assignment
151-
```
152-
153-
### TensorDistribution: Probabilistic Containers
133+
# Unified operations on entire structure - reshape all tensors at once
134+
batch = batch.view(2048)
154135

155-
```python
156-
import torch
157-
from tensorcontainer import TensorDistribution
158-
159-
# Built-in distribution types
160-
from tensorcontainer.tensor_distribution import (
161-
TensorNormal, TensorBernoulli, TensorCategorical,
162-
TensorTruncatedNormal, TensorTanhNormal
163-
)
164-
165-
# Create probabilistic tensor containers
166-
normal_dist = TensorNormal(
167-
loc=torch.zeros(32, 4),
168-
scale=torch.ones(32, 4),
169-
shape=(32,),
170-
device='cpu'
171-
)
172-
173-
# Sample and compute probabilities
174-
samples = normal_dist.sample() # Shape: (32, 4)
175-
log_probs = normal_dist.log_prob(samples)
176-
entropy = normal_dist.entropy()
177-
178-
# Categorical distributions for discrete actions
179-
categorical = TensorCategorical(
180-
logits=torch.randn(32, 6), # 6 possible actions
181-
shape=(32,),
182-
device='cpu'
183-
)
136+
# Type-safe access with autocomplete works on reshaped data too
137+
loss = torch.nn.functional.cross_entropy(batch.features, batch.labels)
184138
```
185139

186-
### PyTree Operations
140+
### 3. TensorDistribution: Probabilistic Modeling
187141

188-
```python
189-
# All containers work seamlessly with PyTree operations
190-
import torch.utils._pytree as pytree
191-
192-
# Transform all tensors in the container
193-
doubled_data = pytree.tree_map(lambda x: x * 2, data)
194-
195-
# Combine multiple containers
196-
combined = pytree.tree_map(lambda x, y: x + y, data1, data2)
197-
```
198-
199-
## Features
200-
201-
- **torch.compile Optimized**: Compatible with PyTorch's JIT compiler
202-
- **PyTree Support**: Integration with `torch.utils._pytree` for tree operations
203-
- **Zero-Copy Operations**: Efficient tensor sharing and manipulation
204-
- **Type Safety**: Static typing support with IDE autocomplete and type checking
205-
- **Batch Semantics**: Consistent batch/event dimension handling
206-
- **Shape Validation**: Automatic validation of tensor shapes and device consistency
207-
- **Multiple Container Types**: Different container types for different use cases
208-
- **Probabilistic Support**: Distribution containers for probabilistic modeling
209-
- **Comprehensive Testing**: Extensive test suite with compile compatibility verification
210-
- **Memory Efficient**: Optimized memory usage with slots-based dataclasses
211-
212-
## API Overview
213-
214-
### Core Components
215-
216-
- **`TensorContainer`**: Base class providing core tensor manipulation operations with batch/event dimension semantics
217-
- **`TensorDict`**: Dictionary-like container for dynamic tensor collections with nested structure support
218-
- **`TensorDataClass`**: DataClass-based container for static, typed tensor structures
219-
- **`TensorDistribution`**: Distribution wrapper for probabilistic tensor operations
220-
221-
### Key Concepts
222-
223-
- **Batch Dimensions**: Leading dimensions defined by the `shape` parameter, consistent across all tensors
224-
- **Event Dimensions**: Trailing dimensions beyond batch shape, can vary per tensor
225-
- **PyTree Integration**: All containers are registered PyTree nodes for seamless tree operations
226-
- **Device Consistency**: Automatic validation ensures all tensors reside on compatible devices
227-
- **Unsafe Construction**: Context manager for performance-critical scenarios with validation bypass
228-
229-
## torch.compile Compatibility
230-
231-
Tensor Container is designed for `torch.compile` compatibility:
142+
Streamline probabilistic computations in reinforcement learning and generative models:
232143

233144
```python
234-
@torch.compile
235-
def process_batch(data: TensorDict) -> TensorDict:
236-
# PyTree operations compile efficiently
237-
return TensorContainer._tree_map(lambda x: torch.relu(x), data)
238-
239-
@torch.compile
240-
def sample_and_score(dist: TensorNormal, actions: torch.Tensor) -> torch.Tensor:
241-
# Distribution operations are compile-safe
242-
return dist.log_prob(actions)
243-
244-
# All operations compile efficiently with minimal graph breaks
245-
compiled_result = process_batch(tensor_dict)
246-
log_probs = sample_and_score(normal_dist, action_tensor)
247-
```
248-
249-
The testing framework includes compile compatibility verification to ensure operations work efficiently under JIT compilation, including:
250-
- Graph break detection and minimization
251-
- Recompilation tracking
252-
- Memory leak prevention
253-
- Performance benchmarking
254-
255-
## Contributing
256-
257-
Contributions are welcome! Tensor Container is a learning project for exploring PyTorch internals and tensor container implementations.
258-
259-
### Development Setup
260-
261-
```bash
262-
# Clone and install in development mode
263-
git clone https://github.com/mctigger/tensor-container.git
264-
cd tensor-container
265-
pip install -e .[dev]
266-
```
267-
268-
### Running Tests
269-
270-
```bash
271-
# Run all tests with coverage
272-
pytest --strict-markers --cov=src
145+
import torch
146+
from tensorcontainer.tensor_distribution import TensorNormal
273147

274-
# Run specific test modules
275-
pytest tests/tensor_dict/test_compile.py
276-
pytest tests/tensor_dataclass/
277-
pytest tests/tensor_distribution/
148+
normal = TensorNormal(
149+
loc=torch.zeros(100, 4),
150+
scale=torch.ones(100, 4)
151+
)
278152

279-
# Run compile-specific tests
280-
pytest tests/tensor_dict/test_graph_breaks.py
281-
pytest tests/tensor_dict/test_recompilations.py
153+
# With torch.distributions we need to extract the parameters, detach them
154+
# and create a new Normal distribution. With TensorDistribution we just call
155+
# .detach() on the distribution. We can also apply other tensor operations,
156+
# such as .view()!
157+
detached_normal = normal.detach()
282158
```
283159

284-
### Development Guidelines
285-
286-
- All new features must maintain `torch.compile` compatibility
287-
- Comprehensive tests required, including compile compatibility verification
288-
- Follow existing code patterns and typing conventions
289-
- Distribution implementations must support KL divergence registration
290-
- Memory efficiency considerations for large-scale tensor operations
291-
- Unsafe construction patterns for performance-critical paths
292-
293-
### Contribution Process
294-
295-
1. Fork the repository
296-
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
297-
3. Make your changes with appropriate tests
298-
4. Ensure all tests pass and maintain coverage
299-
5. Submit a pull request with a clear description
300-
301160
## Documentation
302161

303-
The project includes documentation:
162+
The project includes comprehensive documentation:
304163

305-
- **[`docs/compatibility.md`](docs/compatibility.md)**: Python version compatibility guide and best practices
306-
- **[`docs/testing.md`](docs/testing.md)**: Testing philosophy, standards, and guidelines
164+
- **[`docs/user_guide/overview.md`](docs/user_guide/overview.md)**: Complete user guide with examples and best practices
165+
- **[`docs/developer_guide/compatibility.md`](docs/developer_guide/compatibility.md)**: Python version compatibility guide and best practices
166+
- **[`docs/developer_guide/testing.md`](docs/developer_guide/testing.md)**: Testing philosophy, standards, and guidelines
307167
- **Source Code Documentation**: Extensive docstrings and type annotations throughout the codebase
308168
- **Test Coverage**: 643+ tests covering all major functionality with 86% code coverage
309169

File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)