 # Tensor Container

-*Modern tensor containers for PyTorch with PyTree compatibility and torch.compile optimization*
+*Tensor containers for PyTorch with PyTree compatibility and torch.compile optimization*

-[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
+[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
 [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
 [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org/)

 > **⚠️ Academic Research Project**: This project exists solely for academic purposes to explore and learn PyTorch internals. For production use, please use the official, well-maintained [**torch/tensordict**](https://github.com/pytorch/tensordict) library.

-Tensor Container provides efficient, type-safe tensor container implementations designed for modern PyTorch workflows. Built from the ground up with PyTree integration and torch.compile optimization, it enables seamless batched tensor operations with minimal overhead and maximum performance.
+Tensor Container provides efficient, type-safe tensor container implementations for PyTorch workflows. It includes PyTree integration and torch.compile optimization for batched tensor operations.
+
+The library includes tensor containers, probabilistic distributions, and batch/event dimension semantics for machine learning workflows.

 ## Table of Contents

@@ -18,6 +20,7 @@ Tensor Container provides efficient, type-safe tensor container implementations
 - [API Overview](#api-overview)
 - [torch.compile Compatibility](#torchcompile-compatibility)
 - [Contributing](#contributing)
+- [Documentation](#documentation)
 - [License](#license)
 - [Authors](#authors)
 - [Contact and Support](#contact-and-support)
@@ -40,7 +43,7 @@ pip install -e .[dev]
 
 ### Requirements

-- Python 3.8+
+- Python 3.9+
 - PyTorch 2.0+

 ## Quick Start
@@ -91,6 +94,39 @@ obs = data.observations
 data.actions = torch.randn(32, 8)  # Type-checked assignment
 ```

+### TensorDistribution: Probabilistic Containers
+
+```python
+import torch
+from tensorcontainer import TensorDistribution
+
+# Built-in distribution types
+from tensorcontainer.tensor_distribution import (
+    TensorNormal, TensorBernoulli, TensorCategorical,
+    TensorTruncatedNormal, TensorTanhNormal
+)
+
+# Create probabilistic tensor containers
+normal_dist = TensorNormal(
+    loc=torch.zeros(32, 4),
+    scale=torch.ones(32, 4),
+    shape=(32,),
+    device='cpu'
+)
+
+# Sample and compute probabilities
+samples = normal_dist.sample()  # Shape: (32, 4)
+log_probs = normal_dist.log_prob(samples)
+entropy = normal_dist.entropy()
+
+# Categorical distributions for discrete actions
+categorical = TensorCategorical(
+    logits=torch.randn(32, 6),  # 6 possible actions
+    shape=(32,),
+    device='cpu'
+)
+```
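As a rough illustration of what `log_prob` and `entropy` compute for a normal distribution, the closed-form expressions can be written in plain Python. This is a scalar sketch using only the standard library; the helper names `normal_log_prob` and `normal_entropy` are hypothetical and not part of tensorcontainer, which operates on batched tensors.

```python
import math


def normal_log_prob(x: float, loc: float, scale: float) -> float:
    """Log-density of Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def normal_entropy(scale: float) -> float:
    """Differential entropy of a normal distribution; it does not depend on loc."""
    return 0.5 + 0.5 * math.log(2.0 * math.pi) + math.log(scale)


# Standard normal, matching loc=0 and scale=1 in the example above
print(normal_log_prob(0.0, loc=0.0, scale=1.0))  # about -0.9189
print(normal_entropy(scale=1.0))                 # about 1.4189
```

A tensor-valued distribution would evaluate the same expressions elementwise; how the results are reduced over event dimensions is not shown here.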
+
 ### PyTree Operations

 ```python
@@ -106,21 +142,23 @@ combined = pytree.tree_map(lambda x, y: x + y, data1, data2)
 
 ## Features

-- **🔥 torch.compile Optimized**: Built for maximum performance with PyTorch's JIT compiler
-- **🌳 Native PyTree Support**: Seamless integration with `torch.utils._pytree` for tree operations
-- **⚡ Zero-Copy Operations**: Efficient tensor sharing and manipulation without unnecessary copies
-- **🎯 Type Safety**: Full static typing support with IDE autocomplete and type checking
-- **📊 Batch Semantics**: Consistent batch/event dimension handling across all operations
-- **🔍 Shape Validation**: Automatic validation of tensor shapes and device consistency
-- **🏗️ Flexible Architecture**: Multiple container types for different use cases
-- **🧪 Comprehensive Testing**: Extensive test suite with compile compatibility verification
+- **torch.compile Optimized**: Compatible with PyTorch's JIT compiler
+- **PyTree Support**: Integration with `torch.utils._pytree` for tree operations
+- **Zero-Copy Operations**: Efficient tensor sharing and manipulation
+- **Type Safety**: Static typing support with IDE autocomplete and type checking
+- **Batch Semantics**: Consistent batch/event dimension handling
+- **Shape Validation**: Automatic validation of tensor shapes and device consistency
+- **Multiple Container Types**: Different container types for different use cases
+- **Probabilistic Support**: Distribution containers for probabilistic modeling
+- **Comprehensive Testing**: Extensive test suite with compile compatibility verification
+- **Memory Efficient**: Optimized memory usage with slots-based dataclasses

 ## API Overview

 ### Core Components

-- **`TensorContainer`**: Base class providing core tensor manipulation operations
-- **`TensorDict`**: Dictionary-like container for dynamic tensor collections
+- **`TensorContainer`**: Base class providing core tensor manipulation operations with batch/event dimension semantics
+- **`TensorDict`**: Dictionary-like container for dynamic tensor collections with nested structure support
 - **`TensorDataClass`**: DataClass-based container for static, typed tensor structures
 - **`TensorDistribution`**: Distribution wrapper for probabilistic tensor operations

@@ -130,25 +168,37 @@ combined = pytree.tree_map(lambda x, y: x + y, data1, data2)
 - **Event Dimensions**: Trailing dimensions beyond batch shape, can vary per tensor
 - **PyTree Integration**: All containers are registered PyTree nodes for seamless tree operations
 - **Device Consistency**: Automatic validation ensures all tensors reside on compatible devices
+- **Unsafe Construction**: Context manager for performance-critical scenarios with validation bypass

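The batch/event split described above can be sketched without PyTorch by working on shape tuples. This assumes batch dimensions are the leading dimensions shared by every tensor, which is inferred from the "beyond batch shape" wording and the `shape=(32,)` arguments in the Quick Start; the function `event_shapes` and the observation size of 128 are hypothetical and only serve the illustration.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def event_shapes(shapes: Dict[str, Shape], batch_shape: Shape) -> Dict[str, Shape]:
    """Return the event (trailing) part of each shape.

    Raises ValueError if a shape does not start with batch_shape.
    """
    n = len(batch_shape)
    result = {}
    for name, shape in shapes.items():
        if tuple(shape[:n]) != tuple(batch_shape):
            raise ValueError(
                f"{name}: shape {shape} does not start with batch shape {batch_shape}"
            )
        result[name] = tuple(shape[n:])
    return result


shapes = {"observations": (32, 128), "actions": (32, 8)}
print(event_shapes(shapes, batch_shape=(32,)))
# {'observations': (128,), 'actions': (8,)}
```

A check of this kind is presumably what shape validation performs at construction time, and what an unsafe-construction path would skip; the library's actual logic may differ.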
 ## torch.compile Compatibility

-Tensor Container is designed from the ground up for `torch.compile` compatibility:
+Tensor Container is designed for `torch.compile` compatibility:

 ```python
 @torch.compile
 def process_batch(data: TensorDict) -> TensorDict:
-    return data.apply(lambda x: torch.relu(x))
+    # PyTree operations compile efficiently
+    return TensorContainer._tree_map(lambda x: torch.relu(x), data)

-# Compiles efficiently with minimal graph breaks
+@torch.compile
+def sample_and_score(dist: TensorNormal, actions: torch.Tensor) -> torch.Tensor:
+    # Distribution operations are compile-safe
+    return dist.log_prob(actions)
+
+# All operations compile efficiently with minimal graph breaks
 compiled_result = process_batch(tensor_dict)
+log_probs = sample_and_score(normal_dist, action_tensor)
 ```
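For readers unfamiliar with tree mapping, the idea behind the `_tree_map` call can be approximated in plain Python: apply a function to every leaf of a nested structure and rebuild the same structure around the results. The sketch below is not the library's or PyTorch's implementation; plain floats stand in for tensors, and `tree_map` and `relu` are local helper names chosen for the illustration.

```python
from typing import Any, Callable


def tree_map(fn: Callable[[Any], Any], tree: Any) -> Any:
    """Apply fn to every leaf of a nested dict/list/tuple, preserving structure."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, value) for value in tree)
    return fn(tree)  # anything else is treated as a leaf


def relu(x: float) -> float:
    return max(0.0, x)


nested = {"obs": [-1.0, 2.0], "stats": {"reward": -0.5}}
print(tree_map(relu, nested))
# {'obs': [0.0, 2.0], 'stats': {'reward': 0.0}}
```

Registered PyTree nodes generalize this pattern: each container type declares how to flatten itself into leaves and how to rebuild itself from them.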
 
-Our testing framework includes comprehensive compile compatibility verification to ensure all operations work efficiently under JIT compilation.
+The testing framework includes compile compatibility verification to ensure operations work efficiently under JIT compilation, including:
+- Graph break detection and minimization
+- Recompilation tracking
+- Memory leak prevention
+- Performance benchmarking

 ## Contributing

-We welcome contributions! Tensor Container is a learning project for exploring PyTorch internals and tensor container implementations.
+Contributions are welcome! Tensor Container is a learning project for exploring PyTorch internals and tensor container implementations.

 ### Development Setup

@@ -168,13 +218,21 @@ pytest --strict-markers --cov=src
 # Run specific test modules
 pytest tests/tensor_dict/test_compile.py
 pytest tests/tensor_dataclass/
+pytest tests/tensor_distribution/
+
+# Run compile-specific tests
+pytest tests/tensor_dict/test_graph_breaks.py
+pytest tests/tensor_dict/test_recompilations.py
 ```

 ### Development Guidelines

 - All new features must maintain `torch.compile` compatibility
 - Comprehensive tests required, including compile compatibility verification
 - Follow existing code patterns and typing conventions
+- Distribution implementations must support KL divergence registration
+- Memory efficiency considerations for large-scale tensor operations
+- Unsafe construction patterns for performance-critical paths

 ### Contribution Process

@@ -184,13 +242,22 @@ pytest tests/tensor_dataclass/
 4. Ensure all tests pass and maintain coverage
 5. Submit a pull request with a clear description

+## Documentation
+
+The project includes documentation:
+
+- **[`docs/compatibility.md`](docs/compatibility.md)**: Python version compatibility guide and best practices
+- **[`docs/testing.md`](docs/testing.md)**: Testing philosophy, standards, and guidelines
+- **Source Code Documentation**: Extensive docstrings and type annotations throughout the codebase
+- **Test Coverage**: 643+ tests covering all major functionality with 86% code coverage
+
 ## License

 This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

 ## Authors

-- **Tim Joseph** - *Creator and Lead Developer* - [mctigger](https://github.com/mctigger)
+- **Tim Joseph** - [mctigger](https://github.com/mctigger)

 ## Contact and Support
