@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
35213521
35223522 flat_size = []
35233523 start = 0
3524+ sorting_index = 0
35243525
35253526 def add_single_value (value , key , metadata_dict , dtype , shape , flat_size ):
3526- nonlocal start
3527+ nonlocal start , sorting_index
35273528 n = value .element_size () * value .numel ()
35283529 if need_padding :
35293530 pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
35413542 start ,
35423543 stop ,
35433544 pad ,
3545+ flat_size [- 1 ],
3546+ sorting_index ,
35443547 )
3548+ sorting_index = sorting_index + 1
35453549 start = stop
35463550
35473551 def assign (
@@ -10395,6 +10399,7 @@ def to(self, *args, **kwargs) -> T:
1039510399 pin_memory = non_blocking_pin ,
1039610400 num_threads = num_threads ,
1039710401 non_blocking = non_blocking ,
10402+ compilable = is_dynamo_compiling (),
1039810403 )
1039910404
1040010405 if non_blocking is None :
@@ -10452,14 +10457,49 @@ def to_pinmem(tensor, _to=to):
1045210457 self ._sync_all ()
1045310458 return result
1045410459
def _to_consolidated(
    self, *, device, pin_memory, num_threads, non_blocking, compilable
):
    """Transfer the consolidated storage to ``device`` and rebuild on top of it.

    The flat buffer stored under ``self._consolidated["storage"]`` is moved
    in a single transfer (optionally through pinned memory). The result
    object is then produced by ``_to_consolidated_compile`` or
    ``_to_consolidated_eager`` from the transferred buffer.

    Args:
        device: destination device. Its ``.type`` attribute is read, so it is
            expected to be a ``torch.device`` rather than a string.
        pin_memory: when truthy, the storage is pinned before the transfer.
        num_threads: forwarded to the rebuild helper; ``None`` is treated
            as ``0``.
        non_blocking: when ``False`` or ``None`` a CUDA stream
            synchronization may be issued after the rebuild (see below);
            any other value skips it.
        compilable: when truthy, ``_to_consolidated_compile`` is used,
            otherwise ``_to_consolidated_eager``.

    Returns:
        The object returned by the selected rebuild helper.
    """
    if num_threads is None:
        # unspecified num_threads should mean 0
        num_threads = 0

    storage = self._consolidated["storage"]

    # The raw transfer is excluded from compilation via
    # ``torch.compiler.disable``.
    @torch.compiler.disable()
    def _cast_storage(src):
        if pin_memory:
            src = src.pin_memory()
        # Always issued asynchronously; synchronization, when required, is
        # handled explicitly at the end of this method.
        return src.to(device, non_blocking=True)

    storage_cast = _cast_storage(storage)

    if compilable:
        result = self._to_consolidated_compile(
            device=device, num_threads=num_threads, storage_cast=storage_cast
        )
    else:
        result = self._to_consolidated_eager(
            device=device, num_threads=num_threads, storage_cast=storage_cast
        )

    if non_blocking in (False, None):
        if device.type == "cuda" and non_blocking is False:
            # sending to CUDA: force a sync only when explicitly requested
            cuda_device = device
        elif storage.device.type == "cuda":
            # sending from CUDA: sync unless the caller opted out.
            # Use the full source device (index included) so that the stream
            # of the device that actually holds the data is synchronized,
            # not merely the stream of the current device.
            cuda_device = storage.device
        else:
            cuda_device = None
        if cuda_device is not None:
            torch.cuda.current_stream(cuda_device).synchronize()

    return result
10500+
10501+ def _to_consolidated_eager (self , * , device , num_threads , storage_cast ):
10502+
1046310503 untyped_storage = storage_cast .untyped_storage ()
1046410504
1046510505 def set_ (x ):
@@ -10528,18 +10568,138 @@ def copy_dict(d):
1052810568 }
1052910569
1053010570 result ._consolidated ["metadata" ] = copy_dict (self ._consolidated ["metadata" ])
10531- if non_blocking in (False , None ):
10532- if device .type == "cuda" and non_blocking is False :
10533- # sending to CUDA force sync
10534- cuda_device = device
10535- elif storage .device .type == "cuda" :
10536- # sending from cuda: need sync unless intentionally not asked for
10537- cuda_device = storage .device .type
10538- else :
10539- cuda_device = None
10540- if cuda_device is not None :
10541- torch .cuda .current_stream (cuda_device ).synchronize ()
10571+ return result
10572+
def _to_consolidated_compile(self, *, device, num_threads, storage_cast):
    """Rebuild this object as views over an already-transferred flat storage.

    Selected by ``_to_consolidated`` when it is called with
    ``compilable=True``. ``storage_cast`` is split into one chunk per leaf
    according to the consolidated metadata, and every tensor of ``self`` is
    replaced by a view of its chunk reinterpreted with the tensor's dtype
    and shape. Jagged nested tensors are rebuilt from their separate
    values / offsets / lengths chunks.

    Args:
        device: destination device; forwarded to ``_fast_apply`` as
            ``torch.device(device)``.
        num_threads: forwarded to ``_fast_apply``; ``None`` is treated as 0.
        storage_cast: the consolidated storage tensor, already on ``device``.

    Returns:
        The result of ``_fast_apply``, with ``_consolidated`` set to the new
        storage and, when available, a copy of the metadata.

    Raises:
        RuntimeError: if a nested tensor does not use the jagged layout.
    """

    def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()):
        # Walks the (possibly nested) metadata and, at the root call, returns
        # two lists ordered by each leaf's position: key paths and split sizes.
        # The last two entries of a leaf record are read as (size, position);
        # presumably these are the values appended by ``add_single_value``
        # - confirm against the writer.
        root = lengths is None
        if root:
            lengths = []
            pos = []
            keys = []
        for k, v in metadata["leaves"].items():
            lengths.append(v[-2])
            pos.append(v[-1])
            keys.append(prefix + (k,))
        for k, d in metadata.items():
            # Metadata also holds non-dict values (``copy_dict`` below copies
            # them verbatim); only dict entries can describe a nested level,
            # and a bare ``in`` test on a string or scalar would misbehave.
            if isinstance(d, dict) and "leaves" in d:
                get_tensors_length(
                    d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)
                )
        if root:
            ordered_keys = [None] * len(pos)
            ordered_lengths = [None] * len(pos)
            for p, length, k in zip(pos, lengths, keys):
                ordered_keys[p] = k
                ordered_lengths[p] = length
            return ordered_keys, ordered_lengths

    def split_storage(consolidated):
        # Maps every leaf key path to its chunk of the flat storage.
        keys, splits = get_tensors_length(consolidated["metadata"])
        return dict(zip(keys, consolidated["storage"].split(splits)))

    if num_threads is None:
        # unspecified num_threads should mean 0
        num_threads = 0

    _consolidated = {"storage": storage_cast}
    if "metadata" in self._consolidated:
        # faster than deepcopy
        def copy_dict(d):
            return {
                k: v if not isinstance(v, dict) else copy_dict(v)
                for k, v in d.items()
            }

        _consolidated["metadata"] = copy_dict(self._consolidated["metadata"])

    # NOTE(review): split_storage reads ``_consolidated["metadata"]`` and
    # raises KeyError when the metadata is absent - confirm that case cannot
    # reach this path.
    slice_map = split_storage(_consolidated)

    def view_as(src, dest):
        # Reinterpret the chunk with dest's dtype, drop trailing elements
        # beyond dest.numel() (e.g. padding), then restore dest's shape.
        return src.view(dest.dtype)[: dest.numel()].view(dest.shape)

    def set_(name, x):
        if not isinstance(name, tuple):
            name = (name,)
        if x.is_nested:
            from torch._subclasses.fake_tensor import FakeTensor
            from torch._subclasses.functional_tensor import FunctionalTensor
            from torch.nested._internal.nested_tensor import (
                _tensor_symint_registry,
                NestedTensor,
            )
            from torch.nested._internal.ops import extract_kwargs

            if x.layout != torch.jagged:
                raise RuntimeError(
                    "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
                    "Please raise an issue on GitHub."
                )

            def njt_key(tag):
                # Components of a nested tensor are stored under the same
                # path with a tag prepended to the last key element.
                return (*name[:-1], tag + name[-1])

            kwargs = extract_kwargs(x)
            values = x._values
            lengths = x._lengths
            offsets = x._offsets
            kwargs["offsets"] = view_as(slice_map[njt_key("<NJT_OFFSETS>")], offsets)
            if lengths is not None:
                kwargs["lengths"] = view_as(
                    slice_map[njt_key("<NJT_LENGTHS>")], lengths
                )
                ragged_source = lengths
                # Pick the freshly built view explicitly rather than via
                # ``kwargs.get``, which could return a stale ``None`` if
                # ``extract_kwargs`` already populated the "lengths" key.
                new_thing = kwargs["lengths"]
            else:
                ragged_source = offsets
                new_thing = kwargs["offsets"]
            if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
                from torch._subclasses.functional_tensor import (
                    mb_unwrap_functional_tensor,
                )

                # Temporary hack until we have the union find
                tgt = mb_unwrap_functional_tensor(new_thing)
                src = mb_unwrap_functional_tensor(ragged_source)
                tgt.nested_int_memo = src.nested_int_memo
            else:
                # Give the new ragged tensor the same registry entry as the
                # one it replaces.
                _tensor_symint_registry[new_thing] = _tensor_symint_registry[
                    ragged_source
                ]

            return NestedTensor(
                view_as(slice_map[njt_key("<NJT_VALUES>")], values),
                **kwargs,
            )
        return view_as(slice_map[name], x)

    result = self._fast_apply(
        set_,
        device=torch.device(device),
        num_threads=num_threads,
        named=True,
        nested_keys=True,
    )
    result._consolidated = _consolidated
    return result
1054410704
1054510705 def _sync_all (self ):
0 commit comments