+import copy
 from typing import Any
 
 from pytensor.graph.basic import Variable
 from pytensor.link.basic import JITLinker
+from pytensor.link.utils import unique_name_generator
 
 
 class PytorchLinker(JITLinker):
     """A `Linker` that compiles NumPy-based operations using torch.compile."""
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
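+        # (name, functor) pairs recorded while converting the graph in
+        # `fgraph_convert`; consumed and reset by `jit_compile` below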
+        self.gen_functors = []
+
     def input_filter(self, inp: Any) -> Any:
         from pytensor.link.pytorch.dispatch import pytorch_typify
 
@@ -18,14 +24,68 @@ def output_filter(self, var: Variable, out: Any) -> Any:
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         from pytensor.link.pytorch.dispatch import pytorch_funcify
 
+        # We want names that are globally unique across the entire
+        # pytensor graph, not just within this subgraph
+        generator = unique_name_generator(["torch_linker"])
+
+        # Ensure that torch is aware of the generated
+        # code so we can compile without graph breaks
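+        # Each functor produced by dispatch is recorded under a unique
+        # name so it can later be exposed at a module-level location
+        # that the compiled code can resolve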
+        def conversion_func_register(*args, **kwargs):
+            functor = pytorch_funcify(*args, **kwargs)
+            name = kwargs["unique_name"](functor)
+            self.gen_functors.append((f"_{name}", functor))
+            return functor
+
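+        # Thread the name generator and the registering conversion
+        # function through the dispatch call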
+        built_kwargs = {
+            "unique_name": generator,
+            "conversion_func": conversion_func_register,
+            **kwargs,
+        }
         return pytorch_funcify(
-            fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
+            fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs
         )
 
     def jit_compile(self, fn):
         import torch
 
-        return torch.compile(fn)
+        class wrapper:
+            """
+            torch.compile fails to compile our function when it cannot
+            resolve some of the functors returned from dispatch calls.
+            To avoid leaking those functors, this class holds them and
+            exposes them at the module location torch expects, only for
+            the duration of each call.
+
+            https://discuss.pytorch.org/t/closures-are-being-gcd-and-causing-failures-to-compile/213319
+            """
+
+            def __init__(self, fn, gen_functors):
+                self.fn = torch.compile(fn)
+                self.gen_functors = copy.copy(gen_functors)
+
+            def __call__(self, *args, **kwargs):
+                import pytensor.link.utils
+
+                # Expose each generated functor as an attribute of
+                # `pytensor.link.utils` so the compiled code can find it
+                for n, fn in self.gen_functors:
+                    # n[1:] strips the leading "_" added at registration
+                    setattr(pytensor.link.utils, n[1:], fn)
+
+                res = self.fn(*args, **kwargs)
+
+                # Remove the attributes again so nothing leaks between calls
+                for n, _ in self.gen_functors:
+                    if getattr(pytensor.link.utils, n[1:], False):
+                        delattr(pytensor.link.utils, n[1:])
+
+                return res
+
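+            # Drop our references to the generated functors when the
+            # wrapper itself is collected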
+            def __del__(self):
+                del self.gen_functors
+
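+        # Hand the accumulated functors to the wrapper and reset the
+        # list so the next compilation starts fresh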
+        res = wrapper(fn, self.gen_functors)
+        self.gen_functors = []
+        return res
 
     def create_thunk_inputs(self, storage_map):
         thunk_inputs = []