44# LICENSE file in the root directory of this source tree.
55
66import argparse
7+ import time
8+ from typing import Any
79
810import pytest
911import torch
1012from packaging import version
1113
12- from tensordict import TensorDict
14+ from tensordict import tensorclass , TensorDict
15+ from tensordict .utils import logger as tensordict_logger
1316
1417TORCH_VERSION = version .parse (version .parse (torch .__version__ ).base_version )
1518
1619
@tensorclass
class NJT:
    """Tensorclass holding the internal tensors read off a nested tensor.

    NOTE(review): the source object is presumably a jagged-layout nested
    tensor such as the ones built by ``_make_njt`` -- confirm.
    """

    # Tensors copied from the attributes of the same name on the source tensor.
    _values: torch.Tensor
    _offsets: torch.Tensor
    _lengths: torch.Tensor
    # Result of ``size(0)`` on the source tensor when built via ``from_njt``.
    njt_shape: Any = None

    @classmethod
    def from_njt(cls, njt_tensor):
        """Build an ``NJT`` from ``njt_tensor`` and return a clone of it.

        Reads ``_values``, ``_offsets`` and ``_lengths`` from ``njt_tensor``,
        stores ``njt_tensor.size(0)`` as ``njt_shape`` and returns the result
        of calling ``.clone()`` on the new instance.
        """
        fields = {
            name: getattr(njt_tensor, name)
            for name in ("_values", "_offsets", "_lengths")
        }
        wrapped = cls(**fields, njt_shape=njt_tensor.size(0))
        return wrapped.clone()
35+
36+
@pytest.fixture(scope="function", autouse=True)
def empty_compiler_cache():
    """Call ``torch.compiler.reset()`` before every test function.

    The fixture is autouse, so each test starts right after a compiler reset.
    Nothing is done after the test finishes.
    """
    torch.compiler.reset()
    yield
2741
2842
2943def _make_njt ():
@@ -34,14 +48,29 @@ def _make_njt():
3448 )
3549
3650
def _njt_td():
    """Return a CPU ``TensorDict`` with 128 entries produced by ``_make_njt``.

    Keys are the strings ``"0"`` .. ``"127"``. The layout is flat (one level);
    an earlier revision used a nested 32 x 32 layout instead.
    """
    entries = {}
    for idx in range(128):
        entries[str(idx)] = _make_njt()
    return TensorDict(entries, device="cpu")
4357
4458
@pytest.fixture
def njt_td():
    """Fixture returning a fresh tensordict built by ``_njt_td()``."""
    fresh = _njt_td()
    return fresh
62+
63+
@pytest.fixture
def td():
    """Fixture returning ``_njt_td()`` with every entry wrapped in ``NJT``.

    Each top-level value of the tensordict is replaced, under the same key,
    by ``NJT.from_njt(value)``.
    """
    njtd = _njt_td()
    # Snapshot the items first: the loop body writes back into ``njtd``, and
    # iterating a live view of a container while assigning into it is fragile.
    for key, nested in list(njtd.items()):
        njtd[key] = NJT.from_njt(nested)
    return njtd
72+
73+
4574@pytest .fixture
4675def default_device ():
4776 if torch .cuda .is_available ():
@@ -52,22 +81,139 @@ def default_device():
5281 pytest .skip ("CUDA/MPS is not available" )
5382
5483
@pytest.mark.parametrize(
    "compile_mode,num_threads",
    [
        [False, None],
        # Explicit num_threads variants (4 and 16) are currently disabled.
        ["default", None],
        ["reduce-overhead", None],
    ],
)
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestConsolidate:
    """Benchmarks of ``consolidate`` on the ``td`` and ``njt_td`` fixtures.

    Every test runs once per ``(compile_mode, num_threads)`` pair above. A
    falsy ``compile_mode`` benchmarks the plain Python function; otherwise
    the function is wrapped with ``torch.compile(mode=compile_mode)``.
    """

    def test_consolidate(self, benchmark, td, compile_mode, num_threads):
        """Benchmark ``td.consolidate(num_threads=...)``.

        The first call is timed separately and logged, then three more
        warm-up calls are made before the ``benchmark`` fixture takes over.
        """
        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")

        def consolidate(td, num_threads):
            return td.consolidate(num_threads=num_threads)

        if compile_mode:
            consolidate = torch.compile(
                consolidate, mode=compile_mode, dynamic=True, fullgraph=True
            )

        # perf_counter is monotonic and high-resolution; time.time() is a wall
        # clock that may jump and is not meant for measuring intervals.
        t0 = time.perf_counter()
        consolidate(td, num_threads=num_threads)
        elapsed = time.perf_counter() - t0
        tensordict_logger.info(f"elapsed time first call: {elapsed:.2f} sec")

        for _ in range(3):
            consolidate(td, num_threads=num_threads)

        benchmark(consolidate, td, num_threads)

    # NOTE(review): this test consolidates ``njt_td`` (no ``.to`` call) yet
    # shares its name with ``TestTo.test_to_njt``. When run with
    # ``--benchmark-group-by func`` the two presumably end up in the same
    # benchmark group -- confirm, and consider renaming.
    def test_to_njt(self, benchmark, njt_td, compile_mode, num_threads):
        """Benchmark ``njt_td.consolidate(num_threads=...)``.

        Unlike ``test_consolidate`` the compiled variant does not pass
        ``fullgraph=True``. Three warm-up calls precede the benchmark.
        """
        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024:.2f} Mb")

        def consolidate(td, num_threads):
            return td.consolidate(num_threads=num_threads)

        if compile_mode:
            consolidate = torch.compile(consolidate, mode=compile_mode, dynamic=True)

        for _ in range(3):
            consolidate(njt_td, num_threads=num_threads)

        benchmark(consolidate, njt_td, num_threads)
132+
133+
@pytest.mark.parametrize(
    "consolidated,compile_mode,num_threads",
    [
        [False, False, None],
        [True, False, None],
        ["within", False, None],
        # Explicit num_threads variants (4 and 16) are currently disabled.
        [True, "default", None],
    ],
)
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.1"), reason="requires torch>=2.5.1"
)
class TestTo:
    """Benchmarks of ``.to(default_device)`` on the ``td`` and ``njt_td`` fixtures.

    ``consolidated`` selects the variant:

    * ``False``    -- call ``.to`` on the data as-is;
    * ``True``     -- consolidate once up front, benchmark ``.to`` only;
    * ``"within"`` -- consolidate inside the benchmarked function, so the
      consolidation cost is part of every measured call.
    """

    @staticmethod
    def _bench_to(
        benchmark, data, label, consolidated, device, compile_mode, num_threads
    ):
        """Shared body of the ``test_to*`` benchmarks.

        Logs the size of ``data`` under ``label``, builds the transfer function
        for the requested ``consolidated`` variant, optionally compiles it,
        runs three warm-up calls and hands it to ``benchmark``.
        """
        tensordict_logger.info(f"{label} size {data.bytes() / 1024 / 1024:.2f} Mb")
        # Pinned memory is only requested when the target device is CUDA.
        pin_mem = device.type == "cuda"
        # Identity check on purpose: the string "within" is truthy as well.
        if consolidated is True:
            data = data.consolidate(pin_memory=pin_mem)

        if consolidated == "within":

            def to(td, num_threads):
                return td.consolidate(pin_memory=pin_mem).to(
                    device, num_threads=num_threads
                )

        else:

            def to(td, num_threads):
                return td.to(device, num_threads=num_threads)

        if compile_mode:
            to = torch.compile(to, mode=compile_mode, dynamic=True)

        for _ in range(3):
            to(data, num_threads=num_threads)

        benchmark(to, data, num_threads)

    def test_to(
        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
    ):
        """Benchmark moving the ``td`` fixture to ``default_device``."""
        self._bench_to(
            benchmark,
            td,
            "td",
            consolidated,
            default_device,
            compile_mode,
            num_threads,
        )

    def test_to_njt(
        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
    ):
        """Benchmark moving the ``njt_td`` fixture to ``default_device``."""
        self._bench_to(
            benchmark,
            njt_td,
            "njtd",
            consolidated,
            default_device,
            compile_mode,
            num_threads,
        )
69204
70205
if __name__ == "__main__":
    # Only the arguments argparse does not recognise are needed; they are
    # forwarded to pytest unchanged.
    _, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_args = [
        __file__,
        "--capture",
        "no",
        "--exitfirst",
        "--benchmark-group-by",
        "func",
        "-vvv",
    ]
    # Propagate pytest's exit status so that running this file as a script
    # reports failures to the caller instead of always exiting with 0.
    raise SystemExit(pytest.main(pytest_args + unknown))
0 commit comments