
Commit e696708

Author: Vincent Moens

[Feature] TD+NJT to(device) support

ghstack-source-id: 5f84ebc
Pull Request resolved: #1022

1 parent 7e45bcc · commit e696708

File tree

4 files changed: +190 −52 lines changed


tensordict/_reductions.py

Lines changed: 5 additions & 1 deletion
@@ -99,15 +99,19 @@ def from_metadata(metadata=metadata, prefix=None):
                 value = value[: local_shape.numel()]
             value = value.view(local_shape)
             if key.startswith("<NJT>"):
+                raise RuntimeError
+            elif key.startswith("<NJT_VALUES>"):
                 nested_values = value
                 nested_lengths = None
                 continue
             elif key.startswith("<NJT_LENGTHS>"):
                 nested_lengths = value
                 continue
             elif key.startswith("<NJT_OFFSETS>"):
+                from torch.nested._internal.nested_tensor import NestedTensor
+
                 offsets = value
-                value = torch.nested.nested_tensor_from_jagged(
+                value = NestedTensor(
                     nested_values, offsets=offsets, lengths=nested_lengths
                 )
                 key = key.replace("<NJT_OFFSETS>", "")
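
For context, the <NJT_VALUES>/<NJT_LENGTHS>/<NJT_OFFSETS> entries consumed above are the components of a jagged-layout nested tensor. A minimal sketch of how such components fit together, using only the public API (the commit itself constructs NestedTensor directly so the original tensor's metadata is preserved); the shapes below are made up:

import torch

# Two sequences of lengths 2 and 4, each with 3 features, stored contiguously.
values = torch.randn(6, 3)
offsets = torch.tensor([0, 2, 6])

njt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
print(njt.shape)  # e.g. torch.Size([2, j1, 3]) -- a batch of 2 jagged sequences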

tensordict/base.py

Lines changed: 88 additions & 34 deletions
@@ -20,7 +20,7 @@
 from collections.abc import MutableMapping
 
 from concurrent.futures import Future, ThreadPoolExecutor, wait
-from copy import copy, deepcopy
+from copy import copy
 from functools import partial, wraps
 from pathlib import Path
 from textwrap import indent
@@ -66,6 +66,7 @@
     _prefix_last_key,
     _proc_init,
     _prune_selected_keys,
+    _rebuild_njt_from_njt,
     _set_max_batch_size,
     _shape,
     _split_tensordict,
@@ -3591,7 +3592,7 @@ def assign(
            if getattr(value, "is_nested", False):
                if value.layout is torch.jagged:
                    # Get the values
-                    values = value.values()
+                    values = value._values
                    shape = [v if isinstance(v, int) else -1 for v in values.shape]
                    # Get the offsets
                    offsets = value._offsets
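
The _values/_offsets (and optional _lengths) attributes read here are the internal buffers of a jagged-layout nested tensor; the public accessors expose the same data. A small illustrative sketch (shapes are made up):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged
)
print(nt.values().shape)  # torch.Size([6, 3]) -- all rows stacked contiguously
print(nt.offsets())       # tensor([0, 2, 6]) -- per-sequence boundaries into values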
@@ -3602,10 +3603,14 @@ def assign(
                    # We will rely on the fact that the writing order is preserved in python dict
                    # (since python 3.7). Later, we will read the NJT then the NJT offset in that order
                    # to do the allocation.
-                    flat_key_values[_prefix_last_key(total_key, "<NJT>")] = values
+                    flat_key_values[_prefix_last_key(total_key, "<NJT>")] = value
+                    flat_size.append(0)
+                    flat_key_values[_prefix_last_key(total_key, "<NJT_VALUES>")] = (
+                        values
+                    )
                    add_single_value(
                        values,
-                        _prefix_last_key(key, "<NJT>"),
+                        _prefix_last_key(key, "<NJT_VALUES>"),
                        metadata_dict,
                        values.dtype,
                        shape,
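
The deserialization path relies on that insertion order: the <NJT> placeholder is written first, then the values, then (optionally) the lengths, and finally the offsets, which trigger the rebuild. A minimal sketch of the resulting key layout with made-up keys (the real code builds the key tuples via _prefix_last_key):

import torch

flat_key_values = {}
flat_key_values[("<NJT>a",)] = torch.empty(0)             # stand-in for the NJT, 0 bytes in the flat storage
flat_key_values[("<NJT_VALUES>a",)] = torch.randn(6, 3)   # jagged payload
flat_key_values[("<NJT_OFFSETS>a",)] = torch.tensor([0, 2, 6])  # read last, triggers the rebuild
print(list(flat_key_values))  # iteration follows insertion order, guaranteed since Python 3.7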
@@ -3811,12 +3816,14 @@ def assign(
            start,
            stop,
            njts,
-            njts_offsets,
-            njts_lengths,
            storage=storage,
            non_blocking=non_blocking,
        ):
+            """Reads a slice of the storage and assigns the resulting tensor in flat_dict."""
            # v may need padding
+            if k[-1].startswith("<NJT>"):
+                njts[k] = v
+                return
            v_pad = v.view(-1).view(torch.uint8)
            exp_length = stop - start
            pad = exp_length - v_pad.numel()
@@ -3830,17 +3837,9 @@ def assign(
            if pad:
                new_v = new_v[: v.numel()]
            new_v = new_v.view(shape)
-            if k[-1].startswith("<NJT>"):
-                njts[k] = new_v
-            elif k[-1].startswith("<NJT_LENGTHS>"):
-                njts_lengths[k] = new_v
-            elif k[-1].startswith("<NJT_OFFSETS>"):
-                njts_offsets[k] = new_v
            flat_dict[k] = new_v
 
        njts = {}
-        njts_offsets = {}
-        njts_lengths = {}
        if num_threads > 1:
            executor = ThreadPoolExecutor(num_threads)
            r = []
@@ -3853,8 +3852,6 @@ def assign(
                        start=offsets[i],
                        stop=offsets[i + 1],
                        njts=njts,
-                        njts_offsets=njts_offsets,
-                        njts_lengths=njts_lengths,
                    )
                )
            if not return_early:
@@ -3872,25 +3869,25 @@ def assign(
                    start=offsets[i],
                    stop=offsets[i + 1],
                    njts=njts,
-                    njts_offsets=njts_offsets,
-                    njts_lengths=njts_lengths,
                )
-        for njt_key, njt_val in njts.items():
+        for njt_key, njt in njts.items():
+            newkey = njt_key[:-1] + (njt_key[-1].replace("<NJT>", ""),)
+            njt_key_values = njt_key[:-1] + (
+                njt_key[-1].replace("<NJT>", "<NJT_VALUES>"),
+            )
            njt_key_offset = njt_key[:-1] + (
                njt_key[-1].replace("<NJT>", "<NJT_OFFSETS>"),
            )
            njt_key_lengths = njt_key[:-1] + (
                njt_key[-1].replace("<NJT>", "<NJT_LENGTHS>"),
            )
-            val = torch.nested.nested_tensor_from_jagged(
-                njt_val,
-                offsets=flat_dict[njt_key_offset],
-                lengths=flat_dict.get(njt_key_lengths),
+            val = _rebuild_njt_from_njt(
+                njt,
+                values=flat_dict.pop(njt_key_values),
+                offsets=flat_dict.pop(njt_key_offset),
+                lengths=flat_dict.pop(njt_key_lengths, None),
            )
            del flat_dict[njt_key]
-            del flat_dict[njt_key_offset]
-            flat_dict.pop(njt_key_lengths, None)
-            newkey = njt_key[:-1] + (njt_key[-1].replace("<NJT>", ""),)
            flat_dict[newkey] = val
 
        if non_blocking and device.type != "cuda":
@@ -3910,6 +3907,8 @@ def _view_and_pad(tensor):
 
        items = []
        for v in flat_dict.values():
+            if v.is_nested:
+                continue
            if v.device != storage.device:
                v = v.to(storage.device, non_blocking=non_blocking)
            stride = v.stride()
@@ -3928,9 +3927,13 @@ def _view_and_pad(tensor):
                flat_dict[k] = view_old_as_new(v, oldv)
            elif k[-1].startswith("<NJT>"):
                # NJT/NT always comes before offsets/shapes
-                _nested_values = view_old_as_new(v, oldv)
+                nt = oldv
+                assert not v.numel()
                nt_lengths = None
                del flat_dict[k]
+            elif k[-1].startswith("<NJT_VALUES>"):
+                nt_vaues = view_old_as_new(v, oldv)
+                del flat_dict[k]
            elif k[-1].startswith("<NJT_LENGTHS>"):
                nt_lengths = view_old_as_new(v, oldv)
                del flat_dict[k]
@@ -3939,15 +3942,16 @@ def _view_and_pad(tensor):
                nt_offsets = view_old_as_new(v, oldv)
                del flat_dict[k]
 
-                flat_dict[newk] = torch.nested.nested_tensor_from_jagged(
-                    _nested_values,
-                    offsets=nt_offsets,
-                    lengths=nt_lengths,
+                val = _rebuild_njt_from_njt(
+                    nt, values=nt_vaues, offsets=nt_offsets, lengths=nt_lengths
                )
+
+                flat_dict[newk] = val
+
                # delete the nested value to make sure that if there was an
                # ordering mismatch we wouldn't be looking at the value key of
                # another nested tensor.
-                del _nested_values
+                del nt, nt_vaues, nt_offsets, nt_lengths
            else:
                flat_dict[k] = view_old_as_new(v, oldv)
 
@@ -10459,9 +10463,52 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
        untyped_storage = storage_cast.untyped_storage()
 
        def set_(x):
+            if x.is_nested:
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch._subclasses.functional_tensor import FunctionalTensor
+                from torch.nested._internal.nested_tensor import (
+                    _tensor_symint_registry,
+                    NestedTensor,
+                )
+                from torch.nested._internal.ops import extract_kwargs
+
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                kwargs = extract_kwargs(x)
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                kwargs["offsets"] = set_(offsets)
+                if lengths is not None:
+                    kwargs["lengths"] = set_(lengths)
+                    ragged_source = lengths
+                else:
+                    ragged_source = offsets
+                new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+                if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+                    from torch._subclasses.functional_tensor import (
+                        mb_unwrap_functional_tensor,
+                    )
+
+                    # Temporary hack until we have the union find
+                    tgt = mb_unwrap_functional_tensor(new_thing)
+                    src = mb_unwrap_functional_tensor(ragged_source)
+                    tgt.nested_int_memo = src.nested_int_memo
+                else:
+                    _tensor_symint_registry[new_thing] = _tensor_symint_registry[
+                        ragged_source
+                    ]
+
+                return NestedTensor(
+                    set_(values),
+                    **kwargs,
+                )
            storage_offset = x.storage_offset()
            stride = x.stride()
-            return torch.empty_like(x, device=device).set_(
+            return x.new_empty(0, device=device).set_(
                untyped_storage,
                size=x.shape,
                stride=stride,
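
For non-nested leaves, x.new_empty(0, device=device).set_(...) rebuilds each tensor as a view into the consolidated storage with the original shape, stride and storage offset, rather than allocating with torch.empty_like and then overwriting its storage. A minimal sketch of that pattern on a plain tensor (names and numbers are illustrative):

import torch

base = torch.arange(12, dtype=torch.float32)
storage = base.untyped_storage()

# A 2x2 view over elements 4..7 of the same storage; no data is copied.
view = base.new_empty(0).set_(storage, storage_offset=4, size=(2, 2), stride=(2, 1))
print(view)  # tensor([[4., 5.], [6., 7.]])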
@@ -10473,7 +10520,14 @@ def set_(x):
            )
        result._consolidated = {"storage": storage_cast}
        if "metadata" in self._consolidated:
-            result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"])
+            # faster than deepcopy
+            def copy_dict(d):
+                return {
+                    k: v if not isinstance(v, dict) else copy_dict(v)
+                    for k, v in d.items()
+                }
+
+            result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
        if non_blocking in (False, None):
            if device.type == "cuda" and non_blocking is False:
                # sending to CUDA force sync
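
The recursive copy_dict replaces deepcopy for the consolidated metadata: it duplicates the nested dict structure but shares the leaf values instead of cloning them, which is cheaper and sufficient here because the leaves are not mutated in place. A small sketch of the difference (illustrative data only):

from copy import deepcopy

def copy_dict(d):
    return {k: v if not isinstance(v, dict) else copy_dict(v) for k, v in d.items()}

meta = {"a": {"dtype": "float32", "shape": [3, 4]}, "b": {"shape": [2]}}
copied = copy_dict(meta)
deep = deepcopy(meta)

print(copied["a"] is meta["a"])                    # False: the dict structure is new
print(copied["a"]["shape"] is meta["a"]["shape"])  # True: leaf objects are shared
print(deep["a"]["shape"] is meta["a"]["shape"])    # False: deepcopy clones the leaves too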

tensordict/utils.py

Lines changed: 49 additions & 6 deletions
@@ -1540,16 +1540,26 @@ def assert_close(
        elif not isinstance(input1, torch.Tensor):
            continue
        if input1.is_nested:
-            input1 = input1._base
-            input2 = input2._base
-        mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
+            input1v = input1.values()
+            input2v = input2.values()
+            mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum()
+            input1o = input1.offsets()
+            input2o = input2.offsets()
+            mse = mse + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum()
+        else:
+            mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
        mse = mse.div(input1.numel()).sqrt().item()
 
        local_msg = f"key {key} does not match, got mse = {mse:4.4f}"
        new_msg = ",\t".join([local_msg, msg]) if len(msg) else local_msg
-        torch.testing.assert_close(
-            input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg
-        )
+        if input1.is_nested:
+            torch.testing.assert_close(
+                input1v, input2v, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg
+            )
+        else:
+            torch.testing.assert_close(
+                input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg
+            )
        local_msg = f"key {key} matches"
        msg = "\t".join([local_msg, msg]) if len(msg) else local_msg
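
With this change, nested inputs are compared through their jagged components rather than through ._base. A short sketch of the equivalent check on two jagged nested tensors (made-up data):

import torch

a = torch.nested.nested_tensor([torch.ones(2, 3), torch.ones(4, 3)], layout=torch.jagged)
b = torch.nested.nested_tensor([torch.ones(2, 3), torch.ones(4, 3)], layout=torch.jagged)

# Compare the flat values and the offsets separately, as assert_close now does for NJTs.
torch.testing.assert_close(a.values(), b.values())
torch.testing.assert_close(a.offsets(), b.offsets())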

@@ -2650,3 +2660,36 @@ def parse_tensor_dict_string(s: str):
        raise ValueError("Device not found in the string")
    tensor_dict = TensorDict(fields, batch_size=torch.Size(batch_size), device=device)
    return tensor_dict
+
+
+def _rebuild_njt_from_njt(x, values, offsets, lengths):
+    from torch._subclasses.fake_tensor import FakeTensor
+    from torch._subclasses.functional_tensor import FunctionalTensor
+    from torch.nested._internal.nested_tensor import (
+        _tensor_symint_registry,
+        NestedTensor,
+    )
+    from torch.nested._internal.ops import extract_kwargs
+
+    kwargs = extract_kwargs(x)
+    kwargs["offsets"] = offsets
+    if x._lengths is not None:
+        kwargs["lengths"] = lengths
+        ragged_source = x._lengths
+    else:
+        ragged_source = x._offsets
+    new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+    if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+        from torch._subclasses.functional_tensor import mb_unwrap_functional_tensor
+
+        # Temporary hack until we have the union find
+        tgt = mb_unwrap_functional_tensor(new_thing)
+        src = mb_unwrap_functional_tensor(ragged_source)
+        tgt.nested_int_memo = src.nested_int_memo
+    else:
+        _tensor_symint_registry[new_thing] = _tensor_symint_registry[ragged_source]
+
+    return NestedTensor(
+        values,
+        **kwargs,
+    )
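
A hedged usage sketch of the new helper (it is private, so the import path and behavior may change): rebuild a jagged NJT from an existing one after moving its components to another device, which is essentially what the consolidated to(device) path does per leaf. Assumes a CUDA device is available:

import torch
from tensordict.utils import _rebuild_njt_from_njt

njt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged
)
moved = _rebuild_njt_from_njt(
    njt,
    values=njt.values().to("cuda"),
    offsets=njt.offsets().to("cuda"),
    lengths=None,  # this NJT was built without explicit lengths
)
print(moved.device)  # cuda:0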
