Skip to content

Commit d4a9438

Browse files
author
Vincent Moens
committed
[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: 317a4b9 Pull Request resolved: #1041
1 parent fe6db77 commit d4a9438

File tree

1 file changed

+176
-16
lines changed

1 file changed

+176
-16
lines changed

tensordict/base.py

Lines changed: 176 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
35213521

35223522
flat_size = []
35233523
start = 0
3524+
sorting_index = 0
35243525

35253526
def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
3526-
nonlocal start
3527+
nonlocal start, sorting_index
35273528
n = value.element_size() * value.numel()
35283529
if need_padding:
35293530
pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
35413542
start,
35423543
stop,
35433544
pad,
3545+
flat_size[-1],
3546+
sorting_index,
35443547
)
3548+
sorting_index = sorting_index + 1
35453549
start = stop
35463550

35473551
def assign(
@@ -10395,6 +10399,7 @@ def to(self, *args, **kwargs) -> T:
1039510399
pin_memory=non_blocking_pin,
1039610400
num_threads=num_threads,
1039710401
non_blocking=non_blocking,
10402+
compilable=is_dynamo_compiling(),
1039810403
)
1039910404

1040010405
if non_blocking is None:
@@ -10452,14 +10457,49 @@ def to_pinmem(tensor, _to=to):
1045210457
self._sync_all()
1045310458
return result
1045410459

10455-
def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
10460+
def _to_consolidated(
10461+
self, *, device, pin_memory, num_threads, non_blocking, compilable
10462+
):
1045610463
if num_threads is None:
1045710464
# unspecified num_threads should mean 0
1045810465
num_threads = 0
10466+
1045910467
storage = self._consolidated["storage"]
10460-
if pin_memory:
10461-
storage = storage.pin_memory()
10462-
storage_cast = storage.to(device, non_blocking=True)
10468+
10469+
@torch.compiler.disable()
10470+
def to(storage):
10471+
if pin_memory:
10472+
storage = storage.pin_memory()
10473+
storage_cast = storage.to(device, non_blocking=True)
10474+
return storage_cast
10475+
10476+
storage_cast = to(storage)
10477+
10478+
if compilable:
10479+
result = self._to_consolidated_compile(
10480+
device=device, num_threads=num_threads, storage_cast=storage_cast
10481+
)
10482+
else:
10483+
result = self._to_consolidated_eager(
10484+
device=device, num_threads=num_threads, storage_cast=storage_cast
10485+
)
10486+
10487+
if non_blocking in (False, None):
10488+
if device.type == "cuda" and non_blocking is False:
10489+
# sending to CUDA force sync
10490+
cuda_device = device
10491+
elif storage.device.type == "cuda":
10492+
# sending from cuda: need sync unless intentionally not asked for
10493+
cuda_device = storage.device.type
10494+
else:
10495+
cuda_device = None
10496+
if cuda_device is not None:
10497+
torch.cuda.current_stream(cuda_device).synchronize()
10498+
10499+
return result
10500+
10501+
def _to_consolidated_eager(self, *, device, num_threads, storage_cast):
10502+
1046310503
untyped_storage = storage_cast.untyped_storage()
1046410504

1046510505
def set_(x):
@@ -10528,18 +10568,138 @@ def copy_dict(d):
1052810568
}
1052910569

1053010570
result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
10531-
if non_blocking in (False, None):
10532-
if device.type == "cuda" and non_blocking is False:
10533-
# sending to CUDA force sync
10534-
cuda_device = device
10535-
elif storage.device.type == "cuda":
10536-
# sending from cuda: need sync unless intentionally not asked for
10537-
cuda_device = storage.device.type
10538-
else:
10539-
cuda_device = None
10540-
if cuda_device is not None:
10541-
torch.cuda.current_stream(cuda_device).synchronize()
10571+
return result
10572+
10573+
def _to_consolidated_compile(self, *, device, num_threads, storage_cast):
10574+
10575+
def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()):
10576+
root = False
10577+
if lengths is None:
10578+
lengths = []
10579+
pos = []
10580+
keys = []
10581+
root = True
10582+
for k, v in metadata["leaves"].items():
10583+
lengths.append(v[-2])
10584+
pos.append(v[-1])
10585+
keys.append(prefix + (k,))
10586+
for k, d in metadata.items():
10587+
if "leaves" in d:
10588+
get_tensors_length(
10589+
d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)
10590+
)
10591+
if root:
10592+
# l = torch.empty(len(lengths), dtype=torch.long)
10593+
# l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
10594+
out0 = [
10595+
None,
10596+
] * len(pos)
10597+
out1 = [
10598+
None,
10599+
] * len(pos)
10600+
for p, l, k in zip(pos, lengths, keys):
10601+
out0[p] = k
10602+
out1[p] = l
10603+
return out0, out1
10604+
10605+
def split_storage(consolidated):
10606+
keys, splits = get_tensors_length(consolidated["metadata"])
10607+
return dict(zip(keys, consolidated["storage"].split(splits)))
10608+
10609+
if num_threads is None:
10610+
# unspecified num_threads should mean 0
10611+
num_threads = 0
10612+
10613+
_consolidated = {"storage": storage_cast}
10614+
if "metadata" in self._consolidated:
10615+
# faster than deepcopy
10616+
def copy_dict(d):
10617+
return {
10618+
k: v if not isinstance(v, dict) else copy_dict(v)
10619+
for k, v in d.items()
10620+
}
10621+
10622+
_consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
10623+
10624+
slice_map = split_storage(_consolidated)
10625+
10626+
def view_as(src, dest):
10627+
return src.view(dest.dtype)[: dest.numel()].view(dest.shape)
1054210628

10629+
def set_(name, x):
10630+
if not isinstance(name, tuple):
10631+
name = (name,)
10632+
if x.is_nested:
10633+
from torch._subclasses.fake_tensor import FakeTensor
10634+
from torch._subclasses.functional_tensor import FunctionalTensor
10635+
from torch.nested._internal.nested_tensor import (
10636+
_tensor_symint_registry,
10637+
NestedTensor,
10638+
)
10639+
from torch.nested._internal.ops import extract_kwargs
10640+
10641+
if x.layout != torch.jagged:
10642+
raise RuntimeError(
10643+
"to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
10644+
"Please raise an issue on GitHub."
10645+
)
10646+
kwargs = extract_kwargs(x)
10647+
values = x._values
10648+
lengths = x._lengths
10649+
offsets = x._offsets
10650+
storage_offsets = slice_map[
10651+
(
10652+
*name[:-1],
10653+
"<NJT_OFFSETS>" + name[-1],
10654+
)
10655+
]
10656+
kwargs["offsets"] = view_as(storage_offsets, offsets)
10657+
if lengths is not None:
10658+
storage_lengths = slice_map[
10659+
(
10660+
*name[:-1],
10661+
"<NJT_LENGTHS>" + name[-1],
10662+
)
10663+
]
10664+
kwargs["lengths"] = view_as(storage_lengths, lengths)
10665+
ragged_source = lengths
10666+
else:
10667+
ragged_source = offsets
10668+
new_thing = kwargs.get("lengths", kwargs.get("offsets"))
10669+
if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
10670+
from torch._subclasses.functional_tensor import (
10671+
mb_unwrap_functional_tensor,
10672+
)
10673+
10674+
# Temporary hack until we have the union find
10675+
tgt = mb_unwrap_functional_tensor(new_thing)
10676+
src = mb_unwrap_functional_tensor(ragged_source)
10677+
tgt.nested_int_memo = src.nested_int_memo
10678+
else:
10679+
_tensor_symint_registry[new_thing] = _tensor_symint_registry[
10680+
ragged_source
10681+
]
10682+
10683+
storage_values = slice_map[
10684+
(
10685+
*name[:-1],
10686+
"<NJT_VALUES>" + name[-1],
10687+
)
10688+
]
10689+
return NestedTensor(
10690+
view_as(storage_values, values),
10691+
**kwargs,
10692+
)
10693+
return view_as(slice_map[name], x)
10694+
10695+
result = self._fast_apply(
10696+
set_,
10697+
device=torch.device(device),
10698+
num_threads=num_threads,
10699+
named=True,
10700+
nested_keys=True,
10701+
)
10702+
result._consolidated = _consolidated
1054310703
return result
1054410704

1054510705
def _sync_all(self):

0 commit comments

Comments
 (0)