Commit f429ec8

Merge pull request #1 from mctigger/CI-init
Create python-package.yml, adds Python 3.9 compatibility and no-cuda pytest compatibility
2 parents 161e3fd + 360d5a0 commit f429ec8

28 files changed: 325 additions & 100 deletions
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+name: Python package
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.9", "3.10", "3.11"]
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v5
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        python -m pip install .[dev]
+    - name: Lint and Format with Ruff
+      run: |
+        # Run ruff to check for linting errors
+        ruff check .
+        # Run ruff to check for formatting issues (like black --check)
+        ruff format --check .
+    - name: Test with pytest
+      run: |
+        pytest

README.md

Lines changed: 87 additions & 20 deletions
@@ -1,14 +1,16 @@
 # Tensor Container
 
-*Modern tensor containers for PyTorch with PyTree compatibility and torch.compile optimization*
+*Tensor containers for PyTorch with PyTree compatibility and torch.compile optimization*
 
-[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
+[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
 [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
 [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org/)
 
 > **⚠️ Academic Research Project**: This project exists solely for academic purposes to explore and learn PyTorch internals. For production use, please use the official, well-maintained [**torch/tensordict**](https://github.com/pytorch/tensordict) library.
 
-Tensor Container provides efficient, type-safe tensor container implementations designed for modern PyTorch workflows. Built from the ground up with PyTree integration and torch.compile optimization, it enables seamless batched tensor operations with minimal overhead and maximum performance.
+Tensor Container provides efficient, type-safe tensor container implementations for PyTorch workflows. It includes PyTree integration and torch.compile optimization for batched tensor operations.
+
+The library includes tensor containers, probabilistic distributions, and batch/event dimension semantics for machine learning workflows.
 
 ## Table of Contents
 
@@ -18,6 +20,7 @@ Tensor Container provides efficient, type-safe tensor container implementations
 - [API Overview](#api-overview)
 - [torch.compile Compatibility](#torchcompile-compatibility)
 - [Contributing](#contributing)
+- [Documentation](#documentation)
 - [License](#license)
 - [Authors](#authors)
 - [Contact and Support](#contact-and-support)
@@ -40,7 +43,7 @@ pip install -e .[dev]
 
 ### Requirements
 
-- Python 3.8+
+- Python 3.9+
 - PyTorch 2.0+
 
 ## Quick Start
@@ -91,6 +94,39 @@ obs = data.observations
 data.actions = torch.randn(32, 8) # Type-checked assignment
 ```
 
+### TensorDistribution: Probabilistic Containers
+
+```python
+import torch
+from tensorcontainer import TensorDistribution
+
+# Built-in distribution types
+from tensorcontainer.tensor_distribution import (
+    TensorNormal, TensorBernoulli, TensorCategorical,
+    TensorTruncatedNormal, TensorTanhNormal
+)
+
+# Create probabilistic tensor containers
+normal_dist = TensorNormal(
+    loc=torch.zeros(32, 4),
+    scale=torch.ones(32, 4),
+    shape=(32,),
+    device='cpu'
+)
+
+# Sample and compute probabilities
+samples = normal_dist.sample() # Shape: (32, 4)
+log_probs = normal_dist.log_prob(samples)
+entropy = normal_dist.entropy()
+
+# Categorical distributions for discrete actions
+categorical = TensorCategorical(
+    logits=torch.randn(32, 6), # 6 possible actions
+    shape=(32,),
+    device='cpu'
+)
+```
+
 ### PyTree Operations
 
 ```python
@@ -106,21 +142,23 @@ combined = pytree.tree_map(lambda x, y: x + y, data1, data2)
 
 ## Features
 
-- **🔥 torch.compile Optimized**: Built for maximum performance with PyTorch's JIT compiler
-- **🌳 Native PyTree Support**: Seamless integration with `torch.utils._pytree` for tree operations
-- **⚡ Zero-Copy Operations**: Efficient tensor sharing and manipulation without unnecessary copies
-- **🎯 Type Safety**: Full static typing support with IDE autocomplete and type checking
-- **📊 Batch Semantics**: Consistent batch/event dimension handling across all operations
-- **🔍 Shape Validation**: Automatic validation of tensor shapes and device consistency
-- **🏗️ Flexible Architecture**: Multiple container types for different use cases
-- **🧪 Comprehensive Testing**: Extensive test suite with compile compatibility verification
+- **torch.compile Optimized**: Compatible with PyTorch's JIT compiler
+- **PyTree Support**: Integration with `torch.utils._pytree` for tree operations
+- **Zero-Copy Operations**: Efficient tensor sharing and manipulation
+- **Type Safety**: Static typing support with IDE autocomplete and type checking
+- **Batch Semantics**: Consistent batch/event dimension handling
+- **Shape Validation**: Automatic validation of tensor shapes and device consistency
+- **Multiple Container Types**: Different container types for different use cases
+- **Probabilistic Support**: Distribution containers for probabilistic modeling
+- **Comprehensive Testing**: Extensive test suite with compile compatibility verification
+- **Memory Efficient**: Optimized memory usage with slots-based dataclasses
 
 ## API Overview
 
 ### Core Components
 
-- **`TensorContainer`**: Base class providing core tensor manipulation operations
-- **`TensorDict`**: Dictionary-like container for dynamic tensor collections
+- **`TensorContainer`**: Base class providing core tensor manipulation operations with batch/event dimension semantics
+- **`TensorDict`**: Dictionary-like container for dynamic tensor collections with nested structure support
 - **`TensorDataClass`**: DataClass-based container for static, typed tensor structures
 - **`TensorDistribution`**: Distribution wrapper for probabilistic tensor operations
 
@@ -130,25 +168,37 @@ combined = pytree.tree_map(lambda x, y: x + y, data1, data2)
 - **Event Dimensions**: Trailing dimensions beyond batch shape, can vary per tensor
 - **PyTree Integration**: All containers are registered PyTree nodes for seamless tree operations
 - **Device Consistency**: Automatic validation ensures all tensors reside on compatible devices
+- **Unsafe Construction**: Context manager for performance-critical scenarios with validation bypass
 
 ## torch.compile Compatibility
 
-Tensor Container is designed from the ground up for `torch.compile` compatibility:
+Tensor Container is designed for `torch.compile` compatibility:
 
 ```python
 @torch.compile
 def process_batch(data: TensorDict) -> TensorDict:
-    return data.apply(lambda x: torch.relu(x))
+    # PyTree operations compile efficiently
+    return TensorContainer._tree_map(lambda x: torch.relu(x), data)
 
-# Compiles efficiently with minimal graph breaks
+@torch.compile
+def sample_and_score(dist: TensorNormal, actions: torch.Tensor) -> torch.Tensor:
+    # Distribution operations are compile-safe
+    return dist.log_prob(actions)
+
+# All operations compile efficiently with minimal graph breaks
 compiled_result = process_batch(tensor_dict)
+log_probs = sample_and_score(normal_dist, action_tensor)
 ```
 
-Our testing framework includes comprehensive compile compatibility verification to ensure all operations work efficiently under JIT compilation.
+The testing framework includes compile compatibility verification to ensure operations work efficiently under JIT compilation, including:
+- Graph break detection and minimization
+- Recompilation tracking
+- Memory leak prevention
+- Performance benchmarking
 
 ## Contributing
 
-We welcome contributions! Tensor Container is a learning project for exploring PyTorch internals and tensor container implementations.
+Contributions are welcome! Tensor Container is a learning project for exploring PyTorch internals and tensor container implementations.
 
 ### Development Setup
 
@@ -168,13 +218,21 @@ pytest --strict-markers --cov=src
 # Run specific test modules
 pytest tests/tensor_dict/test_compile.py
 pytest tests/tensor_dataclass/
+pytest tests/tensor_distribution/
+
+# Run compile-specific tests
+pytest tests/tensor_dict/test_graph_breaks.py
+pytest tests/tensor_dict/test_recompilations.py
 ```
 
 ### Development Guidelines
 
 - All new features must maintain `torch.compile` compatibility
 - Comprehensive tests required, including compile compatibility verification
 - Follow existing code patterns and typing conventions
+- Distribution implementations must support KL divergence registration
+- Memory efficiency considerations for large-scale tensor operations
+- Unsafe construction patterns for performance-critical paths
 
 ### Contribution Process
 
@@ -184,13 +242,22 @@ pytest tests/tensor_dataclass/
 4. Ensure all tests pass and maintain coverage
 5. Submit a pull request with a clear description
 
+## Documentation
+
+The project includes documentation:
+
+- **[`docs/compatibility.md`](docs/compatibility.md)**: Python version compatibility guide and best practices
+- **[`docs/testing.md`](docs/testing.md)**: Testing philosophy, standards, and guidelines
+- **Source Code Documentation**: Extensive docstrings and type annotations throughout the codebase
+- **Test Coverage**: 643+ tests covering all major functionality with 86% code coverage
+
 ## License
 
 This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
 
 ## Authors
 
-- **Tim Joseph** - *Creator and Lead Developer* - [mctigger](https://github.com/mctigger)
+- **Tim Joseph** - [mctigger](https://github.com/mctigger)
 
 ## Contact and Support
 

docs/compatibility.md

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# Python Compatibility Guide
+
+This document outlines compatibility considerations and solutions for supporting different Python versions in the tensorcontainer project.
+
+## Supported Python Versions
+
+**tensorcontainer requires Python 3.9 or higher.**
+
+## Compatibility Changes
+
+### Union Types and isinstance() Checks
+
+Union types can now be handled using `typing.get_args()`:
+
+```python
+from typing import Union, get_args
+import torch
+from tensorcontainer import TensorContainer
+
+TDCompatible = Union[torch.Tensor, TensorContainer]
+
+# Use get_args to extract types from Union:
+if isinstance(val, get_args(TDCompatible)):
+    pass
+```
+
+### Type Annotations
+
+**Note**: The `|` operator for union types was introduced in Python 3.10:
+
+```python
+# Python 3.10+ only:
+def func(x: int | str) -> None:
+    pass
+```
+
+For broader compatibility with Python 3.9+, use `Union` from typing:
+
+```python
+from typing import Union
+
+# Compatible with Python 3.9+:
+def func(x: Union[int, str]) -> None:
+    pass
+```
+
+## General Compatibility Tips
+
+### Import Compatibility
+
+Use `from __future__ import annotations` at the top of files to enable forward references and improve compatibility:
+
+```python
+from __future__ import annotations
+
+# This allows using string annotations that are evaluated later
+def func(x: 'SomeClass') -> 'SomeClass':
+    pass
+```
+
+### Version-Specific Features
+
+When using features that are only available in newer Python versions, use version checks:
+
+```python
+import sys
+
+if sys.version_info >= (3, 10):
+    # Use Python 3.10+ features
+    pass
+else:
+    # Fallback for Python 3.9
+    pass
+```
+
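
Note on the `get_args()` pattern in the guide above: it can be tried without PyTorch installed. The following is a minimal sketch, not code from this commit. It assumes stand-in member types (`int`, `str`) and a hypothetical alias name `Scalar` in place of `TDCompatible`.

```python
from typing import Union, get_args

# Hypothetical alias used only for illustration; the guide's TDCompatible
# combines torch.Tensor and TensorContainer instead.
Scalar = Union[int, str]

# get_args() returns the member types of the Union as a plain tuple,
# which isinstance() accepts on every supported Python version.
member_types = get_args(Scalar)


def is_scalar(value: object) -> bool:
    """Return True if value is an instance of any member type of Scalar."""
    return isinstance(value, member_types)


print(member_types)    # (<class 'int'>, <class 'str'>)
print(is_scalar(3))    # True
print(is_scalar(2.5))  # False
```

The tuple form matters because, on Python 3.9, passing the subscripted `Union` object itself to `isinstance()` raises `TypeError`.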

pyproject.toml

Lines changed: 3 additions & 5 deletions
@@ -12,13 +12,12 @@ dependencies = [
     "torch"
 ]
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 keywords = ["deep learning", "tensordict", "pytorch"]
 urls = {Homepage = "https://github.com/mctigger/tensor-container"}
 classifiers = [
     "License :: OSI Approved :: MIT License",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
@@ -28,10 +27,9 @@ classifiers = [
 [project.optional-dependencies]
 dev = [
     "pytest>=7.2,<8.0",
-    # you can add extra test-related tools here:
     "pytest-cov",
-    # "hypothesis",
-    # "flake8",
+    "ruff"
+
 ]
 
 

src/tensorcontainer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,4 +2,4 @@
 from .tensor_dict import TensorDict
 from .tensor_distribution import TensorDistribution
 
-__all__ = ["TensorDataClass", "TensorDict", "TensorDistribution"]
+__all__ = ["TensorDataClass", "TensorDict", "TensorDistribution"]

src/tensorcontainer/tensor_container.py

Lines changed: 1 addition & 1 deletion
@@ -13,9 +13,9 @@
     Optional,
     Tuple,
     Type,
-    TypeAlias,
     Union,
 )
+from typing_extensions import TypeAlias
 
 import torch
 
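
Note on the `TypeAlias` change above: `typing.TypeAlias` was added in Python 3.10, so importing it from `typing` fails on 3.9, and the commit switches to the `typing_extensions` backport. The sketch below shows a version-gated import as a hedged alternative; it is not what the commit does, and the alias name `Shape` and the helper `numel` are hypothetical.

```python
import sys
from typing import Tuple

if sys.version_info >= (3, 10):
    # TypeAlias is part of the standard library from Python 3.10 onward.
    from typing import TypeAlias
else:
    # On Python 3.9 the typing_extensions backport must be installed.
    from typing_extensions import TypeAlias

# Hypothetical alias for illustration; not taken from tensor_container.py.
Shape: TypeAlias = Tuple[int, ...]


def numel(shape: Shape) -> int:
    """Multiply the dimensions of a shape together."""
    total = 1
    for dim in shape:
        total *= dim
    return total


print(numel((32, 4)))  # 128
```

The unconditional `typing_extensions` import used in the commit is simpler. The `dependencies` list shown for `pyproject.toml` names only `torch`, so the import presumably relies on `typing_extensions` being installed as a dependency of PyTorch.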
