Skip to content

Commit 5d1af5a

Browse files
author
Vincent Moens
committed
[Feature] Structured dtype
ghstack-source-id: 46e12ff Pull Request resolved: #1195
1 parent 9a25b88 commit 5d1af5a

File tree

2 files changed

+129
-0
lines changed

2 files changed

+129
-0
lines changed

tensordict/_torch_func.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,3 +728,20 @@ def _stack_uninit_params(list_of_params, dim=0, out=None):
728728
)
729729
out.batch_size = torch.Size([len(list_of_params)])
730730
return out
731+
732+
def implements_for_tdtype(torch_function: Callable) -> Callable[[Callable], Callable]:
    """Build a decorator that registers a handler for ``torch_function``.

    The decorated callable is stored in ``TDTYPE_HANDLED_FUNCTIONS`` (from
    ``tensordict.dtype``) under the key ``torch_function`` and is returned
    unchanged, so the decorator can be stacked or the function reused.

    Args:
        torch_function: the torch callable whose override is being registered.

    Returns:
        A decorator taking the handler and returning it as-is.
    """
    # Imported lazily, at decoration time, rather than at module import.
    from tensordict.dtype import TDTYPE_HANDLED_FUNCTIONS as registry

    def register(handler: Callable) -> Callable:
        registry[torch_function] = handler
        return handler

    # Same effect as decorating ``register`` with ``functools.wraps``: the
    # returned decorator carries the metadata of ``torch_function``.
    return functools.wraps(torch_function)(register)
743+
744+
@implements_for_tdtype(torch.Tensor.view)
def view(tensor: torch.Tensor, dtype: Any) -> TensorDictBase:
    """Handler registered for ``torch.Tensor.view`` with a structured dtype.

    Delegates to ``StructDtype.view``, which iterates ``dtype.items()`` to
    split ``tensor`` and reinterpret each chunk.

    Args:
        tensor: the tensor to reinterpret.
        dtype: the structured dtype describing the leaves; it is passed to
            ``StructDtype.view`` as-is (presumably a ``StructDtype`` instance).

    Returns:
        Whatever ``StructDtype.view(tensor, dtype)`` returns.
    """
    # Local import: ``tensordict.dtype`` is only needed when the handler runs.
    from tensordict.dtype import StructDtype

    return StructDtype.view(tensor, dtype)

tensordict/dtype.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from collections import deque
6+
import orjson as json
7+
from typing import Callable, Any
8+
9+
10+
# Registry of torch-function overrides: maps a torch callable to the handler
# that ``StructDtype.__torch_function__`` invokes in its place.
TDTYPE_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
11+
12+
class StructDtype:
13+
# def __new__(cls, map=None):
14+
# if isinstance(map, StructDtype):
15+
# return map
16+
# return super().__new__(cls)
17+
def __init__(self, map=None):
18+
if map is None:
19+
map = {}
20+
assert isinstance(map, dict)
21+
self._maps = map
22+
23+
@classmethod
24+
def from_td(cls, data: "TensorDictBase"):
25+
from tensordict.base import _is_tensor_collection
26+
self = cls()
27+
map = self._maps
28+
stack = deque()
29+
stack.append((self, data))
30+
while len(stack):
31+
sdtype, local_data = stack.popleft()
32+
map = sdtype._maps
33+
# TODO: handle lazy stacks here
34+
for k, v in local_data.items():
35+
cls = type(v)
36+
if _is_tensor_collection(cls):
37+
# TODO: handle different dtypes here
38+
# TODO: handle LazyStacks here
39+
newmap = map[k] = StructDtype({})
40+
stack.append((newmap, v))
41+
else:
42+
map[k] = {
43+
"shape": v.shape,
44+
"dtype": v.dtype,
45+
}
46+
return self
47+
48+
def items(self, include_nested: bool=False, leaves_only: bool=False):
49+
stack = deque()
50+
stack.append(self)
51+
while len(stack):
52+
node = stack.popleft()
53+
for k, v in node._maps.items():
54+
if isinstance(v, StructDtype):
55+
if include_nested:
56+
stack.append(v)
57+
if not leaves_only:
58+
yield (k, v)
59+
else:
60+
yield k, v
61+
62+
def values(self, include_nested: bool=False, leaves_only: bool=False):
63+
yield from (_, v in self.items(include_nested=include_nested, leaves_only=leaves_only))
64+
65+
def keys(self, include_nested: bool=False, leaves_only: bool=False):
66+
yield from (k, _ in self.items(include_nested=include_nested, leaves_only=leaves_only))
67+
68+
# def json(self):
69+
# return json.dumps(metadata_dict)
70+
71+
@classmethod
72+
def __torch_function__(
73+
cls,
74+
func: Callable,
75+
types: tuple[type, ...],
76+
args: tuple[Any, ...] = (),
77+
kwargs: dict[str, Any] | None = None,
78+
) -> Callable:
79+
if kwargs is None:
80+
kwargs = {}
81+
if func not in TDTYPE_HANDLED_FUNCTIONS:
82+
return NotImplemented
83+
return TDTYPE_HANDLED_FUNCTIONS[func](*args, **kwargs)
84+
85+
86+
@classmethod
87+
def view(cls, tensor, dtype):
88+
from tensordict import TensorDict
89+
ns = []
90+
shapes = []
91+
dts = []
92+
keys = []
93+
stack = deque()
94+
stack.append((dtype.items(), ()))
95+
tensor_itemsize = tensor.dtype.itemsize
96+
while len(stack):
97+
items, prefix = stack.popleft()
98+
for k, dt in items:
99+
currentk = prefix + (k,)
100+
if isinstance(dt, StructDtype):
101+
stack.append((dt.items(), currentk))
102+
continue
103+
assert currentk not in keys, (currentk, keys)
104+
keys.append(currentk)
105+
s = dt["shape"]
106+
dt = dt["dtype"]
107+
shapes.append(s)
108+
dts.append(dt)
109+
nelts = (dt.itemsize * s.numel()) // tensor_itemsize
110+
ns.append(nelts)
111+
112+
return TensorDict({k: v.view(dt).view(shape) for k, v, dt, shape in zip(keys, tensor.split(ns), dts, shapes, strict=True)})

0 commit comments

Comments
 (0)