55from __future__ import annotations
66
77import copyreg
8+ import queue
89from multiprocessing .reduction import ForkingPickler
910
1011import torch
2122
2223
2324def _rebuild_tensordict_files (flat_key_values , metadata_dict , is_shared : bool = False ):
25+ _nt_values_and_keys = queue .Queue ()
26+ _nt_lengths = queue .Queue ()
27+ _nt_offsets = queue .Queue ()
28+
2429 def from_metadata (metadata = metadata_dict , prefix = None ):
30+ metadata = dict (metadata )
31+
32+ _ = metadata .pop ("njt_values_start" , None )
33+ _ = metadata .pop ("njt_lengths_start" , None )
34+ _ = metadata .pop ("njt_offsets_start" , None )
35+
2536 non_tensor = metadata .pop ("non_tensors" )
2637 leaves = metadata .pop ("leaves" )
2738 cls = metadata .pop ("cls" )
@@ -36,22 +47,21 @@ def from_metadata(metadata=metadata_dict, prefix=None):
3647 total_key = (key ,) if prefix is None else prefix + (key ,)
3748 if total_key [- 1 ].startswith ("<NJT>" ):
3849 nested_values = flat_key_values [total_key ]
39- nested_lengths = None
50+ total_key = total_key [:- 1 ] + total_key [- 1 ].replace ("<NJT>" , "" )
51+ _nt_values_and_keys .put ((nested_values , total_key ))
4052 continue
4153 if total_key [- 1 ].startswith ("<NJT_LENGTHS>" ):
4254 nested_lengths = flat_key_values [total_key ]
55+ _nt_lengths .put (nested_lengths )
4356 continue
4457 elif total_key [- 1 ].startswith ("<NJT_OFFSETS" ):
4558 offsets = flat_key_values [total_key ]
46- key = key .replace ("<NJT_OFFSETS>" , "" )
47- value = torch .nested .nested_tensor_from_jagged (
48- nested_values , offsets = offsets , lengths = nested_lengths
49- )
50- del nested_values
51- del nested_lengths
59+ _nt_offsets .put (offsets )
60+ continue
5261 else :
5362 value = flat_key_values [total_key ]
5463 d [key ] = value
64+
5565 for k , v in metadata .items ():
5666 # Each remaining key is a tuple pointing to a sub-tensordict
5767 d [k ] = from_metadata (
@@ -64,7 +74,18 @@ def from_metadata(metadata=metadata_dict, prefix=None):
6474 # result._is_shared = is_shared
6575 return result
6676
67- return from_metadata ()
77+ result = from_metadata ()
78+ # Then assign the nested tensors
79+ while not _nt_values_and_keys .empty ():
80+ vals , key = _nt_values_and_keys .get ()
81+ lengths = _nt_lengths .get ()
82+ offsets = _nt_offsets .get ()
83+ value = torch .nested .nested_tensor_from_jagged (
84+ vals , offsets = offsets , lengths = lengths
85+ )
86+ result ._set_tuple (key , value , inplace = False , validated = True )
87+
88+ return result
6889
6990
7091def _rebuild_tensordict_files_shared (flat_key_values , metadata_dict ):
@@ -75,9 +96,18 @@ def _rebuild_tensordict_files_consolidated(
7596 metadata ,
7697 storage ,
7798):
99+ _nt_values_and_keys = queue .Queue ()
100+ _nt_lengths = queue .Queue ()
101+ _nt_offsets = queue .Queue ()
102+
78103 def from_metadata (metadata = metadata , prefix = None ):
79104 consolidated = {"storage" : storage , "metadata" : metadata }
80105 metadata = dict (metadata )
106+
107+ _ = metadata .pop ("njt_values_start" , None )
108+ _ = metadata .pop ("njt_lengths_start" , None )
109+ _ = metadata .pop ("njt_offsets_start" , None )
110+
81111 non_tensor = metadata .pop ("non_tensors" )
82112 leaves = metadata .pop ("leaves" )
83113 cls = metadata .pop ("cls" )
@@ -99,31 +129,45 @@ def from_metadata(metadata=metadata, prefix=None):
99129 value = value [: local_shape .numel ()]
100130 value = value .view (local_shape )
101131 if key .startswith ("<NJT>" ):
102- nested_values = value
103- nested_lengths = None
132+ key = key .replace ("<NJT>" , "" )
133+ if prefix :
134+ total_key = prefix + (key ,)
135+ else :
136+ total_key = (key ,)
137+ _nt_values_and_keys .put ((value , total_key ))
104138 continue
105139 elif key .startswith ("<NJT_LENGTHS>" ):
106- nested_lengths = value
140+ _nt_lengths . put ( value )
107141 continue
108142 elif key .startswith ("<NJT_OFFSETS>" ):
109- offsets = value
110- value = torch .nested .nested_tensor_from_jagged (
111- nested_values , offsets = offsets , lengths = nested_lengths
112- )
113- key = key .replace ("<NJT_OFFSETS>" , "" )
143+ _nt_offsets .put (value )
144+ if _nt_offsets .qsize () > _nt_lengths .qsize ():
145+ _nt_lengths .put (None )
146+ continue
114147 d [key ] = value
115- for k , v in metadata .items ():
148+ for key , val in metadata .items ():
116149 # Each remaining key is a tuple pointing to a sub-tensordict
117- d [k ] = from_metadata (
118- v , prefix = prefix + (k ,) if prefix is not None else (k ,)
150+ d [key ] = from_metadata (
151+ val , prefix = prefix + (key ,) if prefix is not None else (key ,)
119152 )
120153 result = CLS_MAP [cls ]._from_dict_validated (d , ** cls_metadata )
121154 if is_locked :
122155 result = result .lock_ ()
123156 result ._consolidated = consolidated
124157 return result
125158
126- return from_metadata ()
159+ result = from_metadata ()
160+ # Then assign the nested tensors
161+ while not _nt_values_and_keys .empty ():
162+ vals , key = _nt_values_and_keys .get ()
163+ lengths = _nt_lengths .get ()
164+ offsets = _nt_offsets .get ()
165+ value = torch .nested .nested_tensor_from_jagged (
166+ vals , offsets = offsets , lengths = lengths
167+ )
168+ result ._set_tuple (key , value , inplace = False , validated = True )
169+
170+ return result
127171
128172
129173def _make_td (cls , state ):
0 commit comments