Commit 67d4c07

Author: Vincent Moens (committed)

[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: 0f01625 Pull Request resolved: #1041
1 parent fe6db77 commit 67d4c07

File tree: 1 file changed (+139 -17)


tensordict/base.py

Lines changed: 139 additions & 17 deletions
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):

         flat_size = []
         start = 0
+        sorting_index = 0

         def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
-            nonlocal start
+            nonlocal start, sorting_index
             n = value.element_size() * value.numel()
             if need_padding:
                 pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
                 start,
                 stop,
                 pad,
+                flat_size[-1],
+                sorting_index,
             )
+            sorting_index = sorting_index + 1
             start = stop

         def assign(
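
The two fields appended to each leaf's metadata above, the byte length (flat_size[-1]) and a monotonically increasing sorting_index, are what the compiled path further down uses to split the flat consolidated storage back into per-key slices. Below is a minimal, self-contained sketch of that idea with a made-up metadata layout, keys, and sizes (not the library's exact format):

    import torch

    # Hypothetical per-leaf metadata mirroring the tuple written by
    # add_single_value: (..., start, stop, pad, byte_length, sorting_index).
    leaves = {
        "a": (0, 16, 0, 16, 0),   # 16 bytes of float32, first in storage
        "b": (16, 24, 0, 8, 1),   # 8 bytes of int64, second in storage
    }
    flat_storage = torch.zeros(24, dtype=torch.uint8)  # stand-in for the consolidated buffer

    # Order keys and byte lengths by sorting_index (v[-1]), take lengths from
    # v[-2], then split the flat buffer into named slices -- the same trick
    # used by get_l/split_storage in _to_consolidated_compile below.
    keys = [None] * len(leaves)
    sizes = [None] * len(leaves)
    for k, v in leaves.items():
        keys[v[-1]] = k
        sizes[v[-1]] = v[-2]
    slices = dict(zip(keys, flat_storage.split(sizes)))

    a = slices["a"].view(torch.float32).view(2, 2)  # recover dtype and shape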
@@ -10395,6 +10399,7 @@ def to(self, *args, **kwargs) -> T:
                pin_memory=non_blocking_pin,
                num_threads=num_threads,
                non_blocking=non_blocking,
+               compilable=is_dynamo_compiling(),
            )

        if non_blocking is None:
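
The compilable flag forwarded here comes from is_dynamo_compiling(), which is True only while Dynamo is tracing the function. A toy illustration of that switch (not library code):

    import torch
    from torch.compiler import is_dynamo_compiling

    # The same call site can route to a compile-friendly implementation when
    # traced by Dynamo and to the eager implementation otherwise.
    def which_path():
        return "compile path" if is_dynamo_compiling() else "eager path"

    print(which_path())                 # plain call: the eager branch
    print(torch.compile(which_path)())  # traced call: the compile branch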
@@ -10452,14 +10457,42 @@ def to_pinmem(tensor, _to=to):
            self._sync_all()
        return result

-    def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
+    def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking, compilable):
        if num_threads is None:
            # unspecified num_threads should mean 0
            num_threads = 0
+
        storage = self._consolidated["storage"]
-        if pin_memory:
-            storage = storage.pin_memory()
-        storage_cast = storage.to(device, non_blocking=True)
+
+        @torch.compiler.disable()
+        def to(storage):
+            if pin_memory:
+                storage = storage.pin_memory()
+            storage_cast = storage.to(device, non_blocking=True)
+            return storage_cast
+
+        storage_cast = to(storage)
+
+        if compilable:
+            result = self._to_consolidated_compile(device=device, num_threads=num_threads, storage_cast=storage_cast)
+        else:
+            result = self._to_consolidated_eager(device=device, num_threads=num_threads, storage_cast=storage_cast)
+
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA force sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from cuda: need sync unless intentionally not asked for
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
+    def _to_consolidated_eager(self, *, device, num_threads, storage_cast):
+
        untyped_storage = storage_cast.untyped_storage()

        def set_(x):
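
Note the torch.compiler.disable() wrapper around the small to(storage) closure above: pinning and the host-to-device copy are kept out of the traced graph and always run eagerly, at the cost of a graph break. A minimal sketch of the same pattern, with a made-up workload:

    import torch

    # The decorated helper is never traced; calling it from compiled code
    # triggers a graph break and runs it eagerly.
    @torch.compiler.disable()
    def move(x, device):
        return x.to(device, non_blocking=True)

    @torch.compile
    def f(x):
        y = x * 2              # traced and compiled
        return move(y, "cpu")  # executed outside the compiled graph

    out = f(torch.randn(8))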
@@ -10528,20 +10561,109 @@ def copy_dict(d):
                }

            result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
-        if non_blocking in (False, None):
-            if device.type == "cuda" and non_blocking is False:
-                # sending to CUDA force sync
-                cuda_device = device
-            elif storage.device.type == "cuda":
-                # sending from cuda: need sync unless intentionally not asked for
-                cuda_device = storage.device.type
-            else:
-                cuda_device = None
-            if cuda_device is not None:
-                torch.cuda.current_stream(cuda_device).synchronize()
-
        return result

+    @torch.compile(dynamic=True)
+    def _to_consolidated_compile(self, *, device, num_threads, storage_cast):
+
+        def get_l(metadata, lengths=None, pos=None, keys=None, prefix=()):
+            root = False
+            if lengths is None:
+                lengths = []
+                pos = []
+                keys = []
+                root = True
+            for k, v in metadata["leaves"].items():
+                lengths.append(v[-2])
+                pos.append(v[-1])
+                keys.append(prefix + (k,))
+            for k, d in metadata.items():
+                if "leaves" in d:
+                    get_l(d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,))
+            if root:
+                # l = torch.empty(len(lengths), dtype=torch.long)
+                # l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
+                out0 = [None, ] * len(pos)
+                out1 = [None, ] * len(pos)
+                for p, l, k in zip(pos, lengths, keys):
+                    out0[p] = k
+                    out1[p] = l
+                return out0, out1
+
+        def split_storage(consolidated):
+            keys, splits = get_l(consolidated["metadata"])
+            return dict(zip(keys, consolidated["storage"].split(splits)))
+
+        if num_threads is None:
+            # unspecified num_threads should mean 0
+            num_threads = 0
+
+        _consolidated = {"storage": storage_cast}
+        if "metadata" in self._consolidated:
+            # faster than deepcopy
+            def copy_dict(d):
+                return {
+                    k: v if not isinstance(v, dict) else copy_dict(v)
+                    for k, v in d.items()
+                }
+
+            _consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
+
+        slice_map = split_storage(_consolidated)
+
+        def set_(name, x):
+            if not isinstance(name, tuple):
+                name = (name,)
+            if x.is_nested:
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch._subclasses.functional_tensor import FunctionalTensor
+                from torch.nested._internal.nested_tensor import (
+                    _tensor_symint_registry,
+                    NestedTensor,
+                )
+                from torch.nested._internal.ops import extract_kwargs
+
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                kwargs = extract_kwargs(x)
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                kwargs["offsets"] = slice_map[(*name[:-1], "<NJT_OFFSETS>"+name[-1],)].view(offsets.dtype).view(offsets.shape)
+                if lengths is not None:
+                    kwargs["lengths"] = slice_map[(*name[:-1], "<NJT_LENGTHS>"+name[-1],)].view(lengths.dtype).view(lengths.shape)
+                    ragged_source = lengths
+                else:
+                    ragged_source = offsets
+                new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+                if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+                    from torch._subclasses.functional_tensor import (
+                        mb_unwrap_functional_tensor,
+                    )
+
+                    # Temporary hack until we have the union find
+                    tgt = mb_unwrap_functional_tensor(new_thing)
+                    src = mb_unwrap_functional_tensor(ragged_source)
+                    tgt.nested_int_memo = src.nested_int_memo
+                else:
+                    _tensor_symint_registry[new_thing] = _tensor_symint_registry[
+                        ragged_source
+                    ]
+
+                return NestedTensor(
+                    slice_map[(*name[:-1], "<NJT_VALUES>"+name[-1],)].view(values.dtype).view(values.shape),
+                    **kwargs,
+                )
+            return slice_map[name].view(x.dtype).view(x.shape)
+
+        result = self._fast_apply(
+            set_, device=torch.device(device), num_threads=num_threads, named=True, nested_keys=True,
+        )
+        result._consolidated = _consolidated
+        return result
    def _sync_all(self):
        if _has_cuda:
            # TODO: dynamo doesn't like torch.cuda.is_initialized
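
Taken together, the change targets moving a consolidated TensorDict across devices from inside a compiled region. A rough usage sketch; the keys, shapes, and the CUDA target are illustrative, and consolidate()/to() are the public entry points involved:

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"a": torch.randn(3, 4), "nested": {"b": torch.randn(3)}}, batch_size=[3]
    ).consolidate()  # packs all leaves into one flat storage plus metadata

    @torch.compile(dynamic=True)
    def move(td):
        # Under compile, to() passes compilable=is_dynamo_compiling() (True here),
        # so the _to_consolidated_compile path above should be taken.
        return td.to("cuda")

    if torch.cuda.is_available():
        td_cuda = move(td)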

0 commit comments
