- Updated main README.md with improved structure, removed redundant sections, and enhanced clarity
- Added comprehensive user guide in docs/user_guide/README.md with detailed examples and concepts
- Reorganized examples documentation with new overview and navigation
- Removed outdated examples/tensor_dataclass/README.md
> **⚠️ Academic Research Project**: This project exists solely for academic purposes to explore and learn PyTorch internals. For production use, please use the official, well-maintained [**torch/tensordict**](https://github.com/pytorch/tensordict) library.
Tensor Container provides efficient, type-safe tensor container implementations for PyTorch workflows. It includes PyTree integration and torch.compile optimization for batched tensor operations.

The library includes tensor containers (dict- and dataclass-style) and distributions (a tensor-container equivalent of `torch.distributions`).
## What is TensorContainer?

TensorContainer transforms how you work with structured tensor data in PyTorch by providing **tensor-like operations for entire data structures**. Instead of manually managing individual tensors, TensorContainer lets you treat complex data as unified entities that behave just like regular tensors.

### **Core Benefits**

- **Unified Operations**: Apply tensor operations like `view()`, `permute()`, `detach()`, and device transfers to entire data structures
- **Drop-in Compatibility**: Seamless integration with existing PyTorch workflows and `torch.compile`
- **Zero Boilerplate**: Eliminate manual parameter handling and type-specific operations
- **Type Safety**: Full IDE support with static typing and autocomplete

```python
# Single operation transforms entire structure
data = data.view(2, 3, 4).permute(1, 0, 2).to('cuda').detach()
```

### **Key Features**

- **⚡ JIT Compilation**: Designed for `torch.compile` with `fullgraph=True`, minimizing graph breaks and maximizing performance
- **📐 Batch/Event Semantics**: Clear distinction between batch dimensions (consistent across tensors) and event dimensions (tensor-specific)
- **🔄 Device Management**: Move entire structures between CPU/GPU with single operations and flexible device compatibility
- **🔒 Type Safety**: Full IDE support with static typing and autocomplete
- **🏗️ Multiple Container Types**: Three specialized containers for different use cases:
  - `TensorDict` for dynamic, dictionary-style data collections
  - `TensorDataClass` for type-safe, dataclass-based structures
  - `TensorDistribution` for probabilistic modeling with 40+ probability distributions
- **🔧 Advanced Operations**: Full PyTorch tensor operation support, including `view()`, `permute()`, `stack()`, `cat()`, and more
- **🎯 Advanced Indexing**: Complete PyTorch indexing semantics with boolean masks, tensor indices, and ellipsis support
- **📊 Shape Validation**: Automatic verification of tensor compatibility with detailed error messages
- **🌳 Nested Structure Support**: Build nested structures by composing different TensorContainers
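The boilerplate these containers remove can be seen in a minimal plain-PyTorch sketch. It uses only `torch.utils._pytree` (the private utility the library integrates with); the dict below is a hypothetical stand-in for a container's contents, not library code:

```python
import torch
from torch.utils._pytree import tree_map

# A plain dict of tensors standing in for a container's contents.
batch = {
    'observations': torch.randn(32, 128),
    'actions': torch.randn(32, 4),
}

# Without a container, every structure-wide operation must be mapped
# by hand over each leaf tensor; this is the bookkeeping that a
# container folds into single calls like .detach() or .view().
detached = tree_map(lambda t: t.detach(), batch)
split = tree_map(lambda t: t.reshape(8, 4, *t.shape[1:]), batch)

assert split['observations'].shape == (8, 4, 128)
assert split['actions'].shape == (8, 4, 4)
```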
TensorContainer doesn't just store your data—it makes working with structured tensors as intuitive as working with individual tensors, while maintaining full compatibility with the PyTorch ecosystem you already know.
## Table of Contents

- [What is TensorContainer?](#what-is-tensorcontainer)
TensorContainer transforms how you work with structured tensor data. Instead of managing individual tensors, you can treat entire data structures as unified entities that behave like regular tensors.

```python
# Single operation transforms entire structure
data = data.view(2, 3, 4).permute(1, 0, 2).to('cuda').detach()
```
### 1. TensorDict: Dynamic Data Collections
Perfect for reinforcement learning data and dynamic collections:

```python
import torch
from tensorcontainer import TensorDict

# Create a container for RL training data
data = TensorDict({
    'observations': torch.randn(32, 128),
    'actions': torch.randn(32, 4),
    'rewards': torch.randn(32, 1)
}, shape=(32,))

# Dictionary-like access with tensor operations
obs = data['observations']
data['advantages'] = torch.randn(32, 1)  # Add new fields dynamically
```
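For intuition, the dataclass flavor can be approximated in a few lines of plain PyTorch. The `Batch` class below is a hypothetical sketch, not the library's `TensorDataClass` implementation, and it assumes a single leading batch dimension:

```python
import dataclasses
import torch

# Illustrative sketch only -- NOT the tensorcontainer implementation.
# It approximates what a dataclass-based tensor container provides:
# typed fields plus structure-wide tensor operations.
@dataclasses.dataclass
class Batch:
    observations: torch.Tensor
    actions: torch.Tensor

    def _map(self, fn):
        # Apply fn to every tensor field and rebuild the dataclass.
        return Batch(**{f.name: fn(getattr(self, f.name))
                        for f in dataclasses.fields(self)})

    def detach(self):
        return self._map(lambda t: t.detach())

    def view(self, *shape):
        # Reshape only the leading (batch) dimension, keeping each
        # tensor's trailing event dimensions untouched.
        return self._map(lambda t: t.view(*shape, *t.shape[1:]))

batch = Batch(torch.randn(32, 128), torch.randn(32, 4))
reshaped = batch.view(8, 4)
assert reshaped.observations.shape == (8, 4, 128)
assert reshaped.actions.shape == (8, 4, 4)
```

The real containers add shape validation, device handling, nesting, and PyTree registration on top of this pattern; typed fields are what give the IDE autocomplete mentioned above.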