44# LICENSE file in the root directory of this source tree.
55
66import argparse
7+ import inspect
78
89import pytest
910import torch
1011import torch .nn as nn
1112
1213from tensordict import TensorDict
13- from tensordict .nn import TensorDictModule , TensorDictSequential
14+ from tensordict .nn import TensorDictModule as Mod , TensorDictSequential as Seq
1415from tensordict .prototype .fx import symbolic_trace
1516
1617
18+ def test_fx ():
19+ seq = Seq (
20+ Mod (lambda x : x + 1 , in_keys = ["x" ], out_keys = ["y" ]),
21+ Mod (lambda x , y : (x * y ).sqrt (), in_keys = ["x" , "y" ], out_keys = ["z" ]),
22+ Mod (lambda z , x : z - z , in_keys = ["z" , "x" ], out_keys = ["a" ]),
23+ )
24+ symbolic_trace (seq )
25+
26+
27+ class TestModule (torch .nn .Module ):
28+ def __init__ (self ):
29+ super ().__init__ ()
30+ self .linear = torch .nn .Linear (2 , 2 )
31+
32+ def forward (self , td : TensorDict ) -> torch .Tensor :
33+ vals = td .values () # pyre-ignore[6]
34+ return torch .cat ([val ._values for val in vals ], dim = 0 )
35+
36+
37+ def test_td_scripting () -> None :
38+ for cls in (TensorDict ,):
39+ for name in dir (cls ):
40+ method = inspect .getattr_static (cls , name )
41+ if isinstance (method , classmethod ):
42+ continue
43+ elif isinstance (method , staticmethod ):
44+ continue
45+ elif not callable (method ):
46+ continue
47+ elif not name .startswith ("__" ) or name in ("__init__" , "__setitem__" ):
48+ setattr (cls , name , torch .jit .unused (method ))
49+
50+ m = TestModule ()
51+ td = TensorDict (
52+ a = torch .nested .nested_tensor ([torch .ones ((1 ,))], layout = torch .jagged )
53+ )
54+ m (td )
55+ m = torch .jit .script (m , example_inputs = (td ,))
56+ m .code
57+
58+
1759def test_tensordictmodule_trace_consistency ():
1860 class Net (nn .Module ):
1961 def __init__ (self ):
@@ -24,7 +66,7 @@ def forward(self, x):
2466 logits = self .linear (x )
2567 return logits , torch .sigmoid (logits )
2668
27- module = TensorDictModule (
69+ module = Mod (
2870 Net (),
2971 in_keys = ["input" ],
3072 out_keys = [("outputs" , "logits" ), ("outputs" , "probabilities" )],
@@ -63,15 +105,13 @@ class Masker(nn.Module):
63105 def forward (self , x , mask ):
64106 return torch .softmax (x * mask , dim = 1 )
65107
66- net = TensorDictModule (
67- Net (), in_keys = [("input" , "x" )], out_keys = [("intermediate" , "x" )]
68- )
69- masker = TensorDictModule (
108+ net = Mod (Net (), in_keys = [("input" , "x" )], out_keys = [("intermediate" , "x" )])
109+ masker = Mod (
70110 Masker (),
71111 in_keys = [("intermediate" , "x" ), ("input" , "mask" )],
72112 out_keys = [("output" , "probabilities" )],
73113 )
74- module = TensorDictSequential (net , masker )
114+ module = Seq (net , masker )
75115 graph_module = symbolic_trace (module )
76116
77117 tensordict = TensorDict (
@@ -120,13 +160,11 @@ def forward(self, x):
120160 module2 = Net (50 , 40 )
121161 module3 = Output (40 , 10 )
122162
123- tdmodule1 = TensorDictModule (module1 , ["input" ], ["x" ])
124- tdmodule2 = TensorDictModule (module2 , ["x" ], ["x" ])
125- tdmodule3 = TensorDictModule (module3 , ["x" ], ["probabilities" ])
163+ tdmodule1 = Mod (module1 , ["input" ], ["x" ])
164+ tdmodule2 = Mod (module2 , ["x" ], ["x" ])
165+ tdmodule3 = Mod (module3 , ["x" ], ["probabilities" ])
126166
127- tdmodule = TensorDictSequential (
128- TensorDictSequential (tdmodule1 , tdmodule2 ), tdmodule3
129- )
167+ tdmodule = Seq (Seq (tdmodule1 , tdmodule2 ), tdmodule3 )
130168 graph_module = symbolic_trace (tdmodule )
131169
132170 tensordict = TensorDict ({"input" : torch .rand (32 , 100 )}, [32 ])
0 commit comments