@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
35213521
35223522 flat_size = []
35233523 start = 0
3524+ sorting_index = 0
35243525
35253526 def add_single_value (value , key , metadata_dict , dtype , shape , flat_size ):
3526- nonlocal start
3527+ nonlocal start , sorting_index
35273528 n = value .element_size () * value .numel ()
35283529 if need_padding :
35293530 pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
35413542 start ,
35423543 stop ,
35433544 pad ,
3545+ flat_size [- 1 ],
3546+ sorting_index ,
35443547 )
3548+ sorting_index = sorting_index + 1
35453549 start = stop
35463550
35473551 def assign (
@@ -10390,7 +10394,7 @@ def to(self, *args, **kwargs) -> T:
1039010394 return result
1039110395
1039210396 if self .is_consolidated () and dtype is None :
10393- return self ._to_consolidated (
10397+ return self ._to_consolidated_compile (
1039410398 device = device ,
1039510399 pin_memory = non_blocking_pin ,
1039610400 num_threads = num_threads ,
@@ -10542,6 +10546,124 @@ def copy_dict(d):
1054210546
1054310547 return result
1054410548
def _to_consolidated_compile(self, *, device, pin_memory, num_threads, non_blocking):
    """Compile-friendly variant of ``_to_consolidated``.

    Moves the consolidated flat storage to ``device`` in a single copy, then
    rebuilds every leaf tensor as a dtype/shape view into a slice of the cast
    storage instead of casting leaves one by one.

    Args:
        device: target device (assumed to be a ``torch.device`` here — it is
            passed to ``torch.device(...)`` and ``device.type`` is read below;
            presumably the caller guarantees it is not ``None`` — TODO confirm).
        pin_memory: if true, pin the source storage before the async copy.
        num_threads: forwarded to ``_fast_apply``; ``None`` is normalized to 0.
        non_blocking: controls whether a CUDA sync is issued after the copy.
    """

    def get_l(metadata, lengths=None, pos=None, keys=None, prefix=()):
        # Recursively walk the metadata tree collecting, for every leaf,
        # its padded byte length and its position in the flat storage.
        # NOTE(review): v[-2]/v[-1] are assumed to be the per-leaf flat size
        # and sorting index appended by add_single_value — confirm against
        # the metadata tuple layout in _reduce_vals_and_metadata.
        root = False
        if lengths is None:
            # Top-level call: allocate the accumulators and remember that we
            # must produce the final ordered output on the way out.
            lengths = []
            pos = []
            keys = []
            root = True
        for k, v in metadata["leaves"].items():
            lengths.append(v[-2])
            pos.append(v[-1])
            keys.append(prefix + (k,))
        for k, d in metadata.items():
            # Sub-dicts that themselves contain a "leaves" entry are nested
            # tensordict metadata; recurse into them with an extended prefix.
            if "leaves" in d:
                get_l(d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,))
        if root:
            # l = torch.empty(len(lengths), dtype=torch.long)
            # l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
            # Reorder keys and lengths by their storage position so that the
            # split sizes line up with the physical layout of the flat buffer.
            out0 = [None, ] * len(pos)
            out1 = [None, ] * len(pos)
            for p, l, k in zip(pos, lengths, keys):
                out0[p] = k
                out1[p] = l
            return out0, out1

    def split_storage(consolidated):
        # Map each leaf key tuple to its contiguous slice of the flat storage.
        keys, splits = get_l(consolidated["metadata"])
        return dict(zip(keys, consolidated["storage"].split(splits)))

    if num_threads is None:
        # unspecified num_threads should mean 0
        num_threads = 0
    storage = self._consolidated["storage"]
    if pin_memory:
        storage = storage.pin_memory()
    # Single device transfer for the whole tensordict; individual leaves are
    # reconstructed as views below, so non_blocking=True is safe pending the
    # explicit synchronization at the end of this method.
    storage_cast = storage.to(device, non_blocking=True)

    _consolidated = {"storage": storage_cast}
    if "metadata" in self._consolidated:
        # faster than deepcopy
        def copy_dict(d):
            return {
                k: v if not isinstance(v, dict) else copy_dict(v)
                for k, v in d.items()
            }

        _consolidated["metadata"] = copy_dict(self._consolidated["metadata"])

    slice_map = split_storage(_consolidated)

    def set_(name, x):
        # Rebuild leaf `x` (keyed by the nested key tuple `name`) as a view
        # into the cast storage; nested (NJT/jagged) tensors are reassembled
        # from their serialized values/offsets/lengths components.
        if not isinstance(name, tuple):
            name = (name,)
        if x.is_nested:
            from torch._subclasses.fake_tensor import FakeTensor
            from torch._subclasses.functional_tensor import FunctionalTensor
            from torch.nested._internal.nested_tensor import (
                _tensor_symint_registry,
                NestedTensor,
            )
            from torch.nested._internal.ops import extract_kwargs

            if x.layout != torch.jagged:
                raise RuntimeError(
                    "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
                    "Please raise an issue on GitHub."
                )
            kwargs = extract_kwargs(x)
            values = x._values
            lengths = x._lengths
            offsets = x._offsets
            # The NJT components were serialized under prefixed sibling keys;
            # view() first reinterprets dtype (raw bytes), then reshapes.
            kwargs["offsets"] = slice_map[(*name[:-1], "<NJT_OFFSETS>" + name[-1],)].view(offsets.dtype).view(offsets.shape)
            if lengths is not None:
                kwargs["lengths"] = slice_map[(*name[:-1], "<NJT_LENGTHS>" + name[-1],)].view(lengths.dtype).view(lengths.shape)
                ragged_source = lengths
            else:
                ragged_source = offsets
            new_thing = kwargs.get("lengths", kwargs.get("offsets"))
            if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
                from torch._subclasses.functional_tensor import (
                    mb_unwrap_functional_tensor,
                )

                # Temporary hack until we have the union find
                tgt = mb_unwrap_functional_tensor(new_thing)
                src = mb_unwrap_functional_tensor(ragged_source)
                tgt.nested_int_memo = src.nested_int_memo
            else:
                # Keep the new ragged-dim tensor associated with the same
                # nested symint as the source it replaces.
                _tensor_symint_registry[new_thing] = _tensor_symint_registry[
                    ragged_source
                ]

            return NestedTensor(
                slice_map[(*name[:-1], "<NJT_VALUES>" + name[-1],)].view(values.dtype).view(values.shape),
                **kwargs,
            )
        # Plain tensor: reinterpret the raw byte slice as x's dtype and shape.
        return slice_map[name].view(x.dtype).view(x.shape)

    result = self._fast_apply(
        set_, device=torch.device(device), num_threads=num_threads, named=True, nested_keys=True,
    )
    result._consolidated = _consolidated

    if non_blocking in (False, None):
        if device.type == "cuda" and non_blocking is False:
            # sending to CUDA force sync
            cuda_device = device
        elif storage.device.type == "cuda":
            # sending from cuda: need sync unless intentionally not asked for
            # NOTE(review): this branch assigns the device *type* string
            # ("cuda") while the branch above assigns the device object —
            # confirm torch.cuda.current_stream accepts both forms.
            cuda_device = storage.device.type
        else:
            cuda_device = None
        if cuda_device is not None:
            torch.cuda.current_stream(cuda_device).synchronize()

    return result
1054510667 def _sync_all (self ):
1054610668 if _has_cuda :
1054710669 # TODO: dynamo doesn't like torch.cuda.is_initialized
0 commit comments