@@ -1,23 +1,26 @@
-import sys
+from hashlib import sha256
 from typing import cast
 
 from numba.core.extending import overload
 from numba.np.unsafe.ndarray import to_fixed_tuple
 
+from pytensor.link.numba.cache import compile_numba_function_src
 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch.basic import numba_funcify
+from pytensor.link.numba.dispatch.basic import (
+    numba_funcify_and_cache_key,
+    register_funcify_and_cache_key,
+)
 from pytensor.link.numba.dispatch.vectorize_codegen import (
     _jit_options,
     _vectorized,
     encode_literals,
     store_core_outputs,
 )
-from pytensor.link.utils import compile_function_src
 from pytensor.tensor import TensorVariable, get_vector_length
 from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
 
 
-@numba_funcify.register(BlockwiseWithCoreShape)
+@register_funcify_and_cache_key(BlockwiseWithCoreShape)
 def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
     [blockwise_node] = op.fgraph.apply_nodes
     blockwise_op: Blockwise = blockwise_node.op
@@ -30,7 +33,7 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
         cast(tuple[TensorVariable], node.inputs[:nin]),
         propagate_unbatched_core_inputs=True,
     )
-    core_op_fn = numba_funcify(
+    core_op_fn, core_op_key = numba_funcify_and_cache_key(
         core_op,
         node=core_node,
         parent_node=node,
@@ -58,36 +61,56 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
     src += ")"
 
     to_tuple = numba_basic.numba_njit(
-        compile_function_src(
+        compile_numba_function_src(
             src,
             "to_tuple",
             global_env={"to_fixed_tuple": to_fixed_tuple},
-        ),
-        # cache=True leads to a numba.cloudpickle dump failure in Python 3.10
-        # May be fine in Python 3.11, but I didn't test. It was fine in 3.12
-        cache=sys.version_info >= (3, 12),
-    )
-
-    def blockwise_wrapper(*inputs_and_core_shapes):
-        inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:]
-        tuple_core_shapes = to_tuple(core_shapes)
-        return _vectorized(
-            core_op_fn,
-            input_bc_patterns,
-            output_bc_patterns,
-            output_dtypes,
-            inplace_pattern,
-            (),  # constant_inputs
-            inputs,
-            tuple_core_shapes,
-            None,  # size
         )
+    )
 
     def blockwise(*inputs_and_core_shapes):
-        raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented")
+        raise NotImplementedError(
+            "Numba implementation of Blockwise cannot be evaluated in Python (non-JIT) mode."
+        )
 
     @overload(blockwise, jit_options=_jit_options)
     def ov_blockwise(*inputs_and_core_shapes):
-        return blockwise_wrapper
+        def impl(*inputs_and_core_shapes):
+            inputs, core_shapes = (
+                inputs_and_core_shapes[:nin],
+                inputs_and_core_shapes[nin:],
+            )
+            tuple_core_shapes = to_tuple(core_shapes)
+            return _vectorized(
+                core_op_fn,
+                input_bc_patterns,
+                output_bc_patterns,
+                output_dtypes,
+                inplace_pattern,
+                (),  # constant_inputs
+                inputs,
+                tuple_core_shapes,
+                None,  # size
+            )
+
+        return impl
 
-    return blockwise
+    if core_op_key is None:
+        # We were told the core op cannot be cached
+        blockwise_key = None
+    else:
+        blockwise_key = "_".join(
+            map(
+                str,
+                (
+                    type(op),
+                    type(blockwise_op),
+                    tuple(blockwise_op.destroy_map.items()),
+                    blockwise_op.signature,
+                    input_bc_patterns,
+                    core_op_key,
+                ),
+            )
+        )
+        blockwise_key = sha256(blockwise_key.encode()).hexdigest()
+    return blockwise, blockwise_key
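
The blockwise/ov_blockwise pair in this diff uses numba's standard overload idiom: the plain-Python function is only a stub that raises, and the real implementation is returned from the @overload hook, so it exists only inside JIT-compiled code. A minimal self-contained sketch of that idiom; the scale_sum names are illustrative, not part of this commit:

import numpy as np
from numba import njit
from numba.extending import overload


def scale_sum(x, factor):
    # Pure-Python stub: calling it outside JIT-compiled code is an error.
    raise NotImplementedError("scale_sum can only be used in JIT-compiled code.")


@overload(scale_sum)
def ov_scale_sum(x, factor):
    # Runs at compile time; the returned `impl` is what gets JIT-compiled
    # in place of the stub.
    def impl(x, factor):
        return x.sum() * factor

    return impl


@njit
def caller(x):
    return scale_sum(x, 2.0)


print(caller(np.arange(5.0)))  # 20.0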
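The cache-key logic added at the end of the function composes a key for the whole Blockwise from stable op metadata plus the core op's own key, and propagates None when the core op cannot be cached. A minimal sketch of that scheme under the same assumptions, with hypothetical names:

from hashlib import sha256


def combine_cache_keys(op_metadata: tuple, inner_key: str | None) -> str | None:
    # If the inner function cannot be cached, the wrapper cannot be either.
    if inner_key is None:
        return None
    # Stringify the stable metadata together with the inner key, then hash
    # to a fixed-length hex identifier.
    raw = "_".join(map(str, (*op_metadata, inner_key)))
    return sha256(raw.encode()).hexdigest()


print(combine_cache_keys(("Blockwise", "(m,n)->(n)"), "abc123"))
print(combine_cache_keys(("Blockwise", "(m,n)->(n)"), None))  # None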