44# LICENSE file in the root directory of this source tree.
55
66import argparse
7+ from typing import Any
78
89import pytest
910import torch
1011from packaging import version
1112
12- from tensordict import TensorDict
13+ from tensordict import tensorclass , TensorDict
14+ from tensordict .utils import logger as tensordict_logger
1315
1416TORCH_VERSION = version .parse (version .parse (torch .__version__ ).base_version )
1517
1618
17- @pytest .fixture
18- def td ():
19- return TensorDict (
20- {
21- str (i ): {str (j ): torch .randn (16 , 16 , device = "cpu" ) for j in range (16 )}
22- for i in range (16 )
23- },
24- batch_size = [16 ],
25- device = "cpu" ,
26- )
@tensorclass
class NJT:
    """Tensorclass mirror of a torch nested (jagged) tensor's internal buffers.

    Holds the raw ``_values`` / ``_offsets`` / ``_lengths`` components so the
    jagged tensor can travel through TensorDict machinery as plain tensors.
    """

    _values: torch.Tensor
    _offsets: torch.Tensor
    _lengths: torch.Tensor
    # Size of the leading (jagged) dimension, captured at construction time.
    njt_shape: Any = None

    @classmethod
    def from_njt(cls, njt_tensor):
        """Build an instance from a nested jagged tensor ``njt_tensor``.

        Uses ``cls`` (not a hard-coded ``NJT``) so subclasses of this
        tensorclass get instances of their own type.
        """
        return cls(
            _values=njt_tensor._values,
            _offsets=njt_tensor._offsets,
            _lengths=njt_tensor._lengths,
            njt_shape=njt_tensor.size(0),
        )
34+
35+
@pytest.fixture(autouse=True, scope="function")
def empty_compiler_cache():
    """Reset torch.compile code caches before every test.

    Runs automatically (``autouse``) so each benchmark starts from a cold
    dynamo cache and timings are comparable across parametrizations.
    """
    torch._dynamo.reset_code_caches()
    yield
2740
2841
2942def _make_njt ():
@@ -34,14 +47,27 @@ def _make_njt():
3447 )
3548
3649
def _njt_td():
    """Build a CPU TensorDict with a 32x32 grid of nested-jagged-tensor leaves."""
    leaves = {
        str(outer): {str(inner): _make_njt() for inner in range(32)}
        for outer in range(32)
    }
    return TensorDict(leaves, device="cpu")
4355
4456
@pytest.fixture
def njt_td():
    """Provide a fresh NJT-backed TensorDict for each test."""
    return _njt_td()
60+
61+
@pytest.fixture
def td():
    """Same layout as ``njt_td``, but every leaf is wrapped in the NJT tensorclass."""
    data = _njt_td()
    for outer_key, sub_td in data.items():
        for inner_key, leaf in sub_td.items():
            data[outer_key, inner_key] = NJT.from_njt(leaf)
    return data
69+
70+
4571@pytest .fixture
4672def default_device ():
4773 if torch .cuda .is_available ():
@@ -52,22 +78,81 @@ def default_device():
5278 pytest .skip ("CUDA/MPS is not available" )
5379
5480
@pytest.mark.parametrize(
    "consolidated,compile_mode,num_threads",
    [
        [False, False, None],
        [True, False, None],
        ["within", False, None],
        # [True, False, 4],
        # [True, False, 16],
        # [True, "default", None],
    ],
)
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestTo:
    """Benchmarks for ``TensorDict.to(device)``.

    Parametrized over:
      - ``consolidated``: ``False`` (no consolidation), ``True`` (consolidate
        once, outside the timed region) or ``"within"`` (consolidation is part
        of the timed operation);
      - ``compile_mode``: passed to ``torch.compile`` when truthy;
      - ``num_threads``: forwarded to ``TensorDict.to``.
    """

    @staticmethod
    def _run_to_benchmark(
        benchmark, td, default_device, consolidated, compile_mode, num_threads
    ):
        """Shared driver for both benchmarks: build the transfer fn, warm up, time it."""
        # Pinning host memory only makes sense for CUDA targets.
        pin_mem = default_device.type == "cuda"
        if consolidated is True:
            td = td.consolidate(pin_memory=pin_mem)

        if consolidated == "within":
            # Consolidation happens inside the timed call.
            def to(td, num_threads):
                return td.consolidate(pin_memory=pin_mem).to(
                    default_device, num_threads=num_threads
                )

        else:

            def to(td, num_threads):
                return td.to(default_device, num_threads=num_threads)

        if compile_mode:
            to = torch.compile(to, mode=compile_mode)

        # Warm-up runs amortize compilation / cache population before timing.
        for _ in range(3):
            to(td, num_threads=num_threads)

        benchmark(to, td, num_threads)

    def test_to(
        self, benchmark, consolidated, td, default_device, compile_mode, num_threads
    ):
        tensordict_logger.info(f"td size {td.bytes() / 1024 / 1024:.2f} Mb")
        self._run_to_benchmark(
            benchmark, td, default_device, consolidated, compile_mode, num_threads
        )

    def test_to_njt(
        self, benchmark, consolidated, njt_td, default_device, compile_mode, num_threads
    ):
        tensordict_logger.info(f"njtd size {njt_td.bytes() / 1024 / 1024:.2f} Mb")
        self._run_to_benchmark(
            benchmark, njt_td, default_device, consolidated, compile_mode, num_threads
        )
69151
70152
if __name__ == "__main__":
    # Forward unknown CLI args to pytest; group benchmark results by function.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_args = [
        __file__,
        "--capture",
        "no",
        "--exitfirst",
        "--benchmark-group-by",
        "func",
    ]
    pytest.main(pytest_args + unknown)