@@ -20,7 +20,7 @@
 from collections.abc import MutableMapping

 from concurrent.futures import Future, ThreadPoolExecutor, wait
-from copy import copy, deepcopy
+from copy import copy
 from functools import partial, wraps
 from pathlib import Path
 from textwrap import indent
@@ -66,6 +66,7 @@
     _prefix_last_key,
     _proc_init,
     _prune_selected_keys,
+    _rebuild_njt_from_njt,
     _set_max_batch_size,
     _shape,
     _split_tensordict,
@@ -3591,7 +3592,7 @@ def assign(
             if getattr(value, "is_nested", False):
                 if value.layout is torch.jagged:
                     # Get the values
-                    values = value.values()
+                    values = value._values
                     shape = [v if isinstance(v, int) else -1 for v in values.shape]
                     # Get the offsets
                     offsets = value._offsets
@@ -3602,10 +3603,14 @@ def assign(
                     # We will rely on the fact that the writing order is preserved in python dict
                     # (since python 3.7). Later, we will read the NJT then the NJT offset in that order
                     # to do the allocation.
-                    flat_key_values[_prefix_last_key(total_key, "<NJT>")] = values
+                    flat_key_values[_prefix_last_key(total_key, "<NJT>")] = value
+                    flat_size.append(0)
+                    flat_key_values[_prefix_last_key(total_key, "<NJT_VALUES>")] = (
+                        values
+                    )
                     add_single_value(
                         values,
-                        _prefix_last_key(key, "<NJT>"),
+                        _prefix_last_key(key, "<NJT_VALUES>"),
                         metadata_dict,
                         values.dtype,
                         shape,
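A minimal sketch (not part of the commit), assuming a PyTorch build with jagged-layout nested tensors, of the values/offsets decomposition that the tagged keys above rely on; `nested_tensor_from_jagged` is the public constructor the removed code path used:

    import torch

    values = torch.arange(6, dtype=torch.float32)  # flat buffer holding every row back to back
    offsets = torch.tensor([0, 2, 6])              # row i spans values[offsets[i]:offsets[i + 1]]
    njt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
    assert njt.layout is torch.jagged
    # these are the components the serializer stores under <NJT_VALUES> / <NJT_OFFSETS>
    assert torch.equal(njt._values, values) and torch.equal(njt._offsets, offsets)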
@@ -3811,12 +3816,14 @@ def assign(
                 start,
                 stop,
                 njts,
-                njts_offsets,
-                njts_lengths,
                 storage=storage,
                 non_blocking=non_blocking,
             ):
+                """Reads a slice of the storage and assigns the resulting tensor in flat_dict."""
                 # v may need padding
+                if k[-1].startswith("<NJT>"):
+                    njts[k] = v
+                    return
                 v_pad = v.view(-1).view(torch.uint8)
                 exp_length = stop - start
                 pad = exp_length - v_pad.numel()
@@ -3830,17 +3837,9 @@ def assign(
                 if pad:
                     new_v = new_v[: v.numel()]
                 new_v = new_v.view(shape)
-                if k[-1].startswith("<NJT>"):
-                    njts[k] = new_v
-                elif k[-1].startswith("<NJT_LENGTHS>"):
-                    njts_lengths[k] = new_v
-                elif k[-1].startswith("<NJT_OFFSETS>"):
-                    njts_offsets[k] = new_v
                 flat_dict[k] = new_v

             njts = {}
-            njts_offsets = {}
-            njts_lengths = {}
             if num_threads > 1:
                 executor = ThreadPoolExecutor(num_threads)
                 r = []
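A minimal sketch (not part of the commit) of the byte-level pad/trim round trip the slice reader above performs; the helper names are illustrative only:

    import torch

    def pad_to_slot(t: torch.Tensor, slot_bytes: int) -> torch.Tensor:
        flat = t.view(-1).view(torch.uint8)          # reinterpret the payload as raw bytes
        pad = slot_bytes - flat.numel()
        if pad:
            flat = torch.cat([flat, flat.new_zeros(pad)])
        return flat

    def unpad_from_slot(b: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        flat = b[: like.numel() * like.element_size()].view(like.dtype)
        return flat.view(like.shape)

    x = torch.arange(3, dtype=torch.int64)
    slot = pad_to_slot(x, 32)                        # 24 payload bytes + 8 padding bytes
    assert torch.equal(unpad_from_slot(slot, x), x)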
@@ -3853,8 +3852,6 @@ def assign(
                         start=offsets[i],
                         stop=offsets[i + 1],
                         njts=njts,
-                        njts_offsets=njts_offsets,
-                        njts_lengths=njts_lengths,
                     )
                 )
             if not return_early:
@@ -3872,25 +3869,25 @@ def assign(
                     start=offsets[i],
                     stop=offsets[i + 1],
                     njts=njts,
-                    njts_offsets=njts_offsets,
-                    njts_lengths=njts_lengths,
                 )
-        for njt_key, njt_val in njts.items():
+        for njt_key, njt in njts.items():
+            newkey = njt_key[:-1] + (njt_key[-1].replace("<NJT>", ""),)
+            njt_key_values = njt_key[:-1] + (
+                njt_key[-1].replace("<NJT>", "<NJT_VALUES>"),
+            )
             njt_key_offset = njt_key[:-1] + (
                 njt_key[-1].replace("<NJT>", "<NJT_OFFSETS>"),
             )
             njt_key_lengths = njt_key[:-1] + (
                 njt_key[-1].replace("<NJT>", "<NJT_LENGTHS>"),
             )
-            val = torch.nested.nested_tensor_from_jagged(
-                njt_val,
-                offsets=flat_dict[njt_key_offset],
-                lengths=flat_dict.get(njt_key_lengths),
+            val = _rebuild_njt_from_njt(
+                njt,
+                values=flat_dict.pop(njt_key_values),
+                offsets=flat_dict.pop(njt_key_offset),
+                lengths=flat_dict.pop(njt_key_lengths, None),
             )
             del flat_dict[njt_key]
-            del flat_dict[njt_key_offset]
-            flat_dict.pop(njt_key_lengths, None)
-            newkey = njt_key[:-1] + (njt_key[-1].replace("<NJT>", ""),)
             flat_dict[newkey] = val

         if non_blocking and device.type != "cuda":
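A minimal sketch (not part of the commit) of the key-tagging convention that the earlier comment about Python 3.7 dict ordering refers to; the dict values here are placeholders, not real tensors:

    flat = {}
    flat[("a", "<NJT>x")] = "nested tensor placeholder (flat size 0)"
    flat[("a", "<NJT_VALUES>x")] = "values buffer"
    flat[("a", "<NJT_OFFSETS>x")] = "offsets buffer"

    # insertion order is preserved (Python >= 3.7), so the reader meets the "<NJT>"
    # placeholder first and can locate its companion entries by rewriting the tag
    njt_key = ("a", "<NJT>x")
    values_key = njt_key[:-1] + (njt_key[-1].replace("<NJT>", "<NJT_VALUES>"),)
    offsets_key = njt_key[:-1] + (njt_key[-1].replace("<NJT>", "<NJT_OFFSETS>"),)
    assert list(flat) == [njt_key, values_key, offsets_key]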
@@ -3910,6 +3907,8 @@ def _view_and_pad(tensor):

         items = []
         for v in flat_dict.values():
+            if v.is_nested:
+                continue
             if v.device != storage.device:
                 v = v.to(storage.device, non_blocking=non_blocking)
             stride = v.stride()
@@ -3928,9 +3927,13 @@ def _view_and_pad(tensor):
                 flat_dict[k] = view_old_as_new(v, oldv)
             elif k[-1].startswith("<NJT>"):
                 # NJT/NT always comes before offsets/shapes
-                _nested_values = view_old_as_new(v, oldv)
+                nt = oldv
+                assert not v.numel()
                 nt_lengths = None
                 del flat_dict[k]
+            elif k[-1].startswith("<NJT_VALUES>"):
+                nt_values = view_old_as_new(v, oldv)
+                del flat_dict[k]
             elif k[-1].startswith("<NJT_LENGTHS>"):
                 nt_lengths = view_old_as_new(v, oldv)
                 del flat_dict[k]
@@ -3939,15 +3942,16 @@ def _view_and_pad(tensor):
                 nt_offsets = view_old_as_new(v, oldv)
                 del flat_dict[k]

-                flat_dict[newk] = torch.nested.nested_tensor_from_jagged(
-                    _nested_values,
-                    offsets=nt_offsets,
-                    lengths=nt_lengths,
+                val = _rebuild_njt_from_njt(
+                    nt, values=nt_values, offsets=nt_offsets, lengths=nt_lengths
                 )
+
+                flat_dict[newk] = val
+
                 # delete the nested value to make sure that if there was an
                 # ordering mismatch we wouldn't be looking at the value key of
                 # another nested tensor.
-                del _nested_values
+                del nt, nt_values, nt_offsets, nt_lengths
             else:
                 flat_dict[k] = view_old_as_new(v, oldv)

@@ -10459,9 +10463,52 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
         untyped_storage = storage_cast.untyped_storage()

         def set_(x):
+            if x.is_nested:
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch._subclasses.functional_tensor import FunctionalTensor
+                from torch.nested._internal.nested_tensor import (
+                    _tensor_symint_registry,
+                    NestedTensor,
+                )
+                from torch.nested._internal.ops import extract_kwargs
+
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                kwargs = extract_kwargs(x)
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                kwargs["offsets"] = set_(offsets)
+                if lengths is not None:
+                    kwargs["lengths"] = set_(lengths)
+                    ragged_source = lengths
+                else:
+                    ragged_source = offsets
+                new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+                if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+                    from torch._subclasses.functional_tensor import (
+                        mb_unwrap_functional_tensor,
+                    )
+
+                    # Temporary hack until we have the union find
+                    tgt = mb_unwrap_functional_tensor(new_thing)
+                    src = mb_unwrap_functional_tensor(ragged_source)
+                    tgt.nested_int_memo = src.nested_int_memo
+                else:
+                    _tensor_symint_registry[new_thing] = _tensor_symint_registry[
+                        ragged_source
+                    ]
+
+                return NestedTensor(
+                    set_(values),
+                    **kwargs,
+                )
             storage_offset = x.storage_offset()
             stride = x.stride()
-            return torch.empty_like(x, device=device).set_(
+            return x.new_empty(0, device=device).set_(
                 untyped_storage,
                 size=x.shape,
                 stride=stride,
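A minimal sketch (not part of the commit) of the new_empty(0).set_(...) aliasing pattern the hunk above switches to: a fresh, empty tensor is re-pointed at the shared untyped storage instead of allocating and copying; the buffer below is illustrative.

    import torch

    buf = torch.arange(8, dtype=torch.float32)   # stand-in for the consolidated buffer
    storage = buf.untyped_storage()

    # re-point a fresh tensor at the shared storage; storage_offset is counted in
    # elements of the tensor's dtype, and no data is copied
    view = buf.new_empty(0).set_(storage, storage_offset=4, size=(4,), stride=(1,))
    assert view.untyped_storage().data_ptr() == storage.data_ptr()
    assert torch.equal(view, buf[4:])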
@@ -10473,7 +10520,14 @@ def set_(x):
         )
         result._consolidated = {"storage": storage_cast}
         if "metadata" in self._consolidated:
-            result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"])
+            # faster than deepcopy
+            def copy_dict(d):
+                return {
+                    k: v if not isinstance(v, dict) else copy_dict(v)
+                    for k, v in d.items()
+                }
+
+            result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
         if non_blocking in (False, None):
             if device.type == "cuda" and non_blocking is False:
                 # sending to CUDA force sync
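A small illustration (not part of the commit) of what the recursive copy_dict above trades away relative to deepcopy: the dict skeleton is duplicated while the leaf values stay shared.

    def copy_dict(d):
        return {k: v if not isinstance(v, dict) else copy_dict(v) for k, v in d.items()}

    meta = {"a": {"shape": [2, 3]}, "dtype": "torch.float32"}
    cloned = copy_dict(meta)
    assert cloned == meta
    assert cloned is not meta and cloned["a"] is not meta["a"]   # nested dicts are fresh objects
    assert cloned["a"]["shape"] is meta["a"]["shape"]            # leaves (lists, tensors) are shared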