> **⚠️ Academic Research Project**: This project exists solely for academic purposes to explore and learn PyTorch internals. For production use, please use the official, well-maintained [**torch/tensordict**](https://github.com/pytorch/tensordict) library.
Tensor Container provides efficient, type-safe tensor container implementations for PyTorch workflows. It includes PyTree integration and torch.compile optimization for batched tensor operations.

The library includes tensor containers (dict- and dataclass-based) and distributions (an API-compatible equivalent of `torch.distributions`).
## What is TensorContainer?

TensorContainer transforms how you work with structured tensor data in PyTorch by providing **tensor-like operations for entire data structures**. Instead of manually managing individual tensors, TensorContainer lets you treat complex data as unified entities that behave just like regular tensors.

### **Core Benefits**

- **Unified Operations**: Apply tensor operations like `view()`, `permute()`, `detach()`, and device transfers to entire data structures
- **Drop-in Compatibility**: Seamless integration with existing PyTorch workflows and `torch.compile`
- **Zero Boilerplate**: Eliminate manual parameter handling and type-specific operations
- **Type Safety**: Full IDE support with static typing and autocomplete

```python
# Single operation transforms entire structure
data = data.view(2, 3, 4).permute(1, 0, 2).to('cuda').detach()
```
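For contrast, here is roughly what that one-liner replaces when written against raw tensors: the same reshape/permute/detach chain repeated for every field. This is a sketch in plain PyTorch; the field names and shapes are illustrative, not part of the library.

```python
import torch

# Hypothetical fields with shared batch dims (6, 4) and per-field event dims.
fields = {
    'observations': torch.randn(6, 4, 128),
    'actions': torch.randn(6, 4, 3),
}

# Without a container, each tensor must be transformed separately,
# and the batch-dim reshape/permute must skip the trailing event dims.
transformed = {
    name: t.view(2, 3, 4, *t.shape[2:]).permute(1, 0, 2, 3).detach()
    for name, t in fields.items()
}
```

A container collapses this loop into a single call on the structure, which is the boilerplate the **Zero Boilerplate** bullet above refers to.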

### **Key Features**

- **⚡ JIT Compilation**: Designed for `torch.compile` with `fullgraph=True`, minimizing graph breaks and maximizing performance
- **📐 Batch/Event Semantics**: Clear distinction between batch dimensions (consistent across tensors) and event dimensions (tensor-specific)
- **🔄 Device Management**: Move entire structures between CPU/GPU with single operations and flexible device compatibility
- **🔒 Type Safety**: Full IDE support with static typing and autocomplete
- **🏗️ Multiple Container Types**: Three specialized containers for different use cases:
  - `TensorDict` for dynamic, dictionary-style data collections
  - `TensorDataClass` for type-safe, dataclass-based structures
  - `TensorDistribution` for probabilistic modeling with 40+ probability distributions
- **🔧 Advanced Operations**: Full PyTorch tensor operations support, including `view()`, `permute()`, `stack()`, `cat()`, and more
- **🎯 Advanced Indexing**: Complete PyTorch indexing semantics with boolean masks, tensor indices, and ellipsis support
- **📊 Shape Validation**: Automatic verification of tensor compatibility with detailed error messages
- **🌳 Nested Structure Support**: Nest different TensorContainer types inside one another
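The batch/event distinction above can be made concrete with plain tensors: every tensor in a container must agree on the leading batch dimensions, while the trailing event dimensions are free to differ per tensor. The `check_batch_shape` helper below is a hypothetical illustration of that rule, not the library's actual validator.

```python
import torch

def check_batch_shape(tensors, batch_shape):
    """Verify every tensor's leading dims match the shared batch shape."""
    n = len(batch_shape)
    for name, t in tensors.items():
        if tuple(t.shape[:n]) != tuple(batch_shape):
            raise ValueError(
                f"{name}: leading dims {tuple(t.shape[:n])} "
                f"do not match batch shape {tuple(batch_shape)}"
            )

data = {
    'observations': torch.randn(32, 128),  # event shape (128,)
    'rewards': torch.randn(32, 1),         # event shape (1,)
}
check_batch_shape(data, (32,))  # both tensors share batch dim 32
```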

TensorContainer doesn't just store your data—it makes working with structured tensors as intuitive as working with individual tensors, while maintaining full compatibility with the PyTorch ecosystem you already know.

## Table of Contents

- [What is TensorContainer?](#what-is-tensorcontainer)

TensorContainer transforms how you work with structured tensor data. Instead of managing individual tensors, you can treat entire data structures as unified entities that behave like regular tensors.

```python
# Single operation transforms entire structure
data = data.view(2, 3, 4).permute(1, 0, 2).to('cuda').detach()
```

### 1. TensorDict: Dynamic Data Collections

Perfect for reinforcement learning data and dynamic collections:

```python
import torch
from tensorcontainer import TensorDict

# Create a container for RL training data
data = TensorDict({
    'observations': torch.randn(32, 128),
    'actions': torch.randn(32, 4),
    'rewards': torch.randn(32, 1)
}, shape=(32,))

# Dictionary-like access with tensor operations
obs = data['observations']
data['advantages'] = torch.randn(32, 1)  # Add new fields dynamically
```
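The PyTree integration mentioned at the top is what lets a single container-level call fan out across every stored tensor. A rough illustration with `torch.utils._pytree` on a plain dict; TensorContainer's actual pytree registration may differ.

```python
import torch
import torch.utils._pytree as pytree

batch = {
    'observations': torch.randn(32, 128),
    'actions': torch.randn(32, 4),
}

# tree_map applies one function to every tensor leaf, which is how a
# single structure-level op can reach all nested tensors at once.
detached = pytree.tree_map(lambda t: t.detach(), batch)

# Flattening yields the leaf tensors plus a spec to rebuild the structure,
# which is also what makes such structures traceable by torch.compile.
leaves, spec = pytree.tree_flatten(batch)
```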