     _split_tensordict,
     _td_fields,
     _unravel_key_to_tuple,
-    _zip_strict,
+    _zip_strict,
+    _to_escape_compile,
     cache,
     convert_ellipsis_to_idx,
     DeviceType,
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
 
         flat_size = []
         start = 0
+        sorting_index = 0
 
         def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
-            nonlocal start
+            nonlocal start, sorting_index
             n = value.element_size() * value.numel()
             if need_padding:
                 pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
                 start,
                 stop,
                 pad,
+                flat_size[-1],
+                sorting_index,
             )
+            sorting_index = sorting_index + 1
             start = stop
 
         def assign(
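
Each leaf's metadata entry now additionally records its padded byte length (`flat_size[-1]`) and a running `sorting_index`, which is what lets the compile path further down (`get_tensors_length` / `split_storage`) rebuild per-key slices from the single flat storage. A minimal sketch of that idea, with a hypothetical leaf layout in which only the last two fields matter:

```python
import torch

# Hypothetical metadata, mirroring the tuple extended above:
# key -> (dtype_str, shape, start, stop, pad, padded_length, sorting_index)
leaves = {
    "a": ("torch.float32", [2, 2], 0, 16, 0, 16, 0),
    "b": ("torch.float32", [3], 16, 28, 4, 16, 1),
}

# Recover per-key byte lengths in storage order from the last two fields,
# then split the flat byte buffer back into one slice per key.
ordered = sorted(leaves.items(), key=lambda kv: kv[1][-1])
keys = [k for k, _ in ordered]
lengths = [v[-2] for _, v in ordered]

flat = torch.zeros(sum(lengths), dtype=torch.uint8)  # stand-in for the consolidated storage
slices = dict(zip(keys, flat.split(lengths)))
assert [t.numel() for t in slices.values()] == [16, 16]
```
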
@@ -10441,6 +10445,7 @@ def to(self, *args, **kwargs) -> T:
                 pin_memory=non_blocking_pin,
                 num_threads=num_threads,
                 non_blocking=non_blocking,
+                compilable=is_dynamo_compiling(),
             )
 
         if non_blocking is None:
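
`to()` now forwards `compilable=is_dynamo_compiling()` so the consolidated path knows whether it is currently being traced. A small sketch of how such a flag behaves, assuming the imported helper acts like the public `torch.compiler.is_dynamo_compiling()`:

```python
import torch

def which_path(x):
    # True only while TorchDynamo traces this function under torch.compile
    if torch.compiler.is_dynamo_compiling():
        return x + 1  # stand-in for the graph-friendly reconstruction
    return x - 1      # stand-in for the eager, storage-based reconstruction

print(which_path(torch.zeros(())))                 # tensor(-1.): eager path
print(torch.compile(which_path)(torch.zeros(())))  # tensor(1.): compiled path
```
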
@@ -10498,14 +10503,42 @@ def to_pinmem(tensor, _to=to):
             self._sync_all()
         return result
 
-    def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
+    def _to_consolidated(
+        self, *, device, pin_memory, num_threads, non_blocking, compilable
+    ):
         if num_threads is None:
             # unspecified num_threads should mean 0
             num_threads = 0
+
         storage = self._consolidated["storage"]
-        if pin_memory:
-            storage = storage.pin_memory()
-        storage_cast = storage.to(device, non_blocking=True)
+
+        storage_cast = _to_escape_compile(storage, device=device, pin_memory=pin_memory)
+
+        if compilable:
+            result = self._to_consolidated_compile(
+                device=device, num_threads=num_threads, storage_cast=storage_cast
+            )
+        else:
+            result = self._to_consolidated_eager(
+                device=device, num_threads=num_threads, storage_cast=storage_cast
+            )
+
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA force sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from cuda: need sync unless intentionally not asked for
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
+    def _to_consolidated_eager(self, *, device, num_threads, storage_cast):
+
         untyped_storage = storage_cast.untyped_storage()
 
         def set_(x):
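
The raw storage copy is now funneled through `_to_escape_compile`, replacing the inline `pin_memory()` / `.to(device, non_blocking=True)` pair it removes. The helper's body is not part of this diff; a plausible, hedged sketch of such a helper (the decorator and the `_sketch` name here are assumptions) would keep the transfer out of the traced graph:

```python
import torch

@torch.compiler.disable  # assumption: the real helper escapes Dynamo tracing
def _to_escape_compile_sketch(storage, *, device, pin_memory):
    # Same effect as the removed lines above: optional pinning, then an async copy.
    if pin_memory:
        storage = storage.pin_memory()
    return storage.to(device, non_blocking=True)
```
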
@@ -10574,18 +10607,138 @@ def copy_dict(d):
                 }
 
             result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
-        if non_blocking in (False, None):
-            if device.type == "cuda" and non_blocking is False:
-                # sending to CUDA force sync
-                cuda_device = device
-            elif storage.device.type == "cuda":
-                # sending from cuda: need sync unless intentionally not asked for
-                cuda_device = storage.device.type
-            else:
-                cuda_device = None
-            if cuda_device is not None:
-                torch.cuda.current_stream(cuda_device).synchronize()
+        return result
+
+    def _to_consolidated_compile(self, *, device, num_threads, storage_cast):
+
+        def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()):
+            root = False
+            if lengths is None:
+                lengths = []
+                pos = []
+                keys = []
+                root = True
+            for k, v in metadata["leaves"].items():
+                lengths.append(v[-2])
+                pos.append(v[-1])
+                keys.append(prefix + (k,))
+            for k, d in metadata.items():
+                if "leaves" in d:
+                    get_tensors_length(
+                        d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)
+                    )
+            if root:
+                # l = torch.empty(len(lengths), dtype=torch.long)
+                # l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
+                out0 = [
+                    None,
+                ] * len(pos)
+                out1 = [
+                    None,
+                ] * len(pos)
+                for p, l, k in zip(pos, lengths, keys):
+                    out0[p] = k
+                    out1[p] = l
+                return out0, out1
+
+        def split_storage(consolidated):
+            keys, splits = get_tensors_length(consolidated["metadata"])
+            return dict(zip(keys, consolidated["storage"].split(splits)))
+
+        if num_threads is None:
+            # unspecified num_threads should mean 0
+            num_threads = 0
+
+        _consolidated = {"storage": storage_cast}
+        if "metadata" in self._consolidated:
+            # faster than deepcopy
+            def copy_dict(d):
+                return {
+                    k: v if not isinstance(v, dict) else copy_dict(v)
+                    for k, v in d.items()
+                }
+
+            _consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
+
+        slice_map = split_storage(_consolidated)
+
+        def view_as(src, dest):
+            return src.view(dest.dtype)[: dest.numel()].view(dest.shape)
 
+        def set_(name, x):
+            if not isinstance(name, tuple):
+                name = (name,)
+            if x.is_nested:
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch._subclasses.functional_tensor import FunctionalTensor
+                from torch.nested._internal.nested_tensor import (
+                    _tensor_symint_registry,
+                    NestedTensor,
+                )
+                from torch.nested._internal.ops import extract_kwargs
+
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                kwargs = extract_kwargs(x)
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                storage_offsets = slice_map[
+                    (
+                        *name[:-1],
+                        "<NJT_OFFSETS>" + name[-1],
+                    )
+                ]
+                kwargs["offsets"] = view_as(storage_offsets, offsets)
+                if lengths is not None:
+                    storage_lengths = slice_map[
+                        (
+                            *name[:-1],
+                            "<NJT_LENGTHS>" + name[-1],
+                        )
+                    ]
+                    kwargs["lengths"] = view_as(storage_lengths, lengths)
+                    ragged_source = lengths
+                else:
+                    ragged_source = offsets
+                new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+                if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+                    from torch._subclasses.functional_tensor import (
+                        mb_unwrap_functional_tensor,
+                    )
+
+                    # Temporary hack until we have the union find
+                    tgt = mb_unwrap_functional_tensor(new_thing)
+                    src = mb_unwrap_functional_tensor(ragged_source)
+                    tgt.nested_int_memo = src.nested_int_memo
+                else:
+                    _tensor_symint_registry[new_thing] = _tensor_symint_registry[
+                        ragged_source
+                    ]
+
+                storage_values = slice_map[
+                    (
+                        *name[:-1],
+                        "<NJT_VALUES>" + name[-1],
+                    )
+                ]
+                return NestedTensor(
+                    view_as(storage_values, values),
+                    **kwargs,
+                )
+            return view_as(slice_map[name], x)
+
+        result = self._fast_apply(
+            set_,
+            device=torch.device(device),
+            num_threads=num_threads,
+            named=True,
+            nested_keys=True,
+        )
+        result._consolidated = _consolidated
         return result
 
     def _sync_all(self):
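
Taken together, the new `_to_consolidated_compile` path means a consolidated TensorDict (including jagged nested tensors) can be moved across devices from inside a `torch.compile`d function, with every leaf rebuilt as a view over a slice of the single cast storage. A rough usage sketch, assuming a CUDA device is available and a tensordict build that includes this change:

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.randn(4, 3), "nested": {"b": torch.arange(4)}},
    batch_size=[4],
).consolidate()  # single flat storage + metadata describing every leaf

@torch.compile
def move(td):
    # dispatches to the compile-friendly consolidated path while tracing
    return td.to("cuda")

td_cuda = move(td)
assert td_cuda["a"].device.type == "cuda"
```
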