
Commit 943f18d: "Saner defaults"
1 parent 980ae7c

10 files changed: +84, -37 lines


pytensor/link/numba/cache.py

Lines changed: 15 additions & 9 deletions
@@ -1,3 +1,4 @@
+import warnings
 import weakref
 from collections.abc import Callable
 from functools import singledispatch, wraps
@@ -25,8 +26,6 @@ def __init__(self, py_func, py_file, hash):
         self._py_func = py_func
         self._py_file = py_file
         self._hash = hash
-        # src_hash = hash(pytensor_loader._module_sources[self._py_file])
-        # self._hash = hash((src_hash, py_file, pytensor.__version__))

     def ensure_cache_path(self):
         pass
@@ -185,27 +184,34 @@ def dispatch_func_wrapper(*args, **kwargs):
     return decorator


-def numba_njit_and_cache(op, node, **kwargs):
-    jitable_func, key = numba_funcify_and_cache_key(op, node=node, **kwargs)
+def numba_njit_and_cache(op, *args, **kwargs):
+    jitable_func, key = numba_funcify_and_cache_key(op, *args, **kwargs)

     if key is not None:
         # To force numba to use our cache, we must compile the function so that any closure
         # becomes a global variable...
         op_name = op.__class__.__name__
-        cached_func = compile_and_cache_numba_function_src(
+        cached_func = compile_numba_function_src(
             src=f"def {op_name}(*args): return jitable_func(*args)",
             function_name=op_name,
-            global_env=globals() | dict(jitable_func=jitable_func),
+            global_env=globals() | {"jitable_func": jitable_func},
             cache_key=key,
         )
-        return numba_njit(cached_func, final_function=True, cache=True)
+        return numba_njit(cached_func, final_function=True, cache=True), key
     else:
+        if config.numba__cache and config.compiler_verbose:
+            warnings.warn(
+                f"Custom numba cache disabled for {op} of type {type(op)}. "
+                f"Even if the function is cached by numba, larger graphs using this function cannot be cached.\n"
+                "To enable custom caching, register a numba_funcify_and_cache_key implementation for this Op, with a proper cache key."
+            )
+
         return numba_njit(
             lambda *args: jitable_func(*args), final_function=True, cache=False
-        )
+        ), None


-def compile_and_cache_numba_function_src(
+def compile_numba_function_src(
     src: str,
     function_name: str,
     global_env: dict[Any, Any] | None = None,
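
With this change, numba_njit_and_cache returns a (compiled_function, cache_key) pair in both branches, and warns when an Op has no registered cache key. A minimal sketch of the new calling convention, assuming a hypothetical my_op / my_node pair for which numba_funcify_and_cache_key is registered:

    from pytensor.link.numba.cache import numba_njit_and_cache

    # Hypothetical Op instance and Apply node (placeholders, not real API objects).
    compiled_fn, cache_key = numba_njit_and_cache(my_op, node=my_node)

    if cache_key is None:
        # No stable key could be derived: the function is still jitted,
        # but graphs containing this Op cannot use the custom on-disk cache.
        print("custom caching disabled for", my_op)
    else:
        print("cache key:", cache_key)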

pytensor/link/numba/dispatch/basic.py

Lines changed: 40 additions & 5 deletions
@@ -3,6 +3,8 @@
 import warnings
 from collections.abc import Callable
 from functools import singledispatch
+from hashlib import sha256
+from pickle import dumps

 import numba
 import numpy as np
@@ -17,11 +19,13 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.ifelse import IfElse
 from pytensor.link.numba.cache import (
     numba_njit_and_cache,
+    register_funcify_and_cache_key,
     register_funcify_default_op_cache_key,
 )
 from pytensor.link.numba.compile import (
@@ -226,21 +230,52 @@ def numba_funcify_fallback(
     return generate_fallback_impl(op, node, storage_map, **kwargs)


-@numba_funcify.register(FunctionGraph)
+def key_for_constant(data):
+    """Create a cache key for a constant value."""
+    # TODO: This is just a placeholder
+    if isinstance(data, (int | float | bool | type(None))):
+        return str(data)
+    try:
+        # For NumPy arrays
+        return sha256(data.tobytes()).hexdigest()
+    except AttributeError:
+        # Fallback for other types
+        return sha256(dumps(data)).hexdigest()
+
+
+@register_funcify_and_cache_key(FunctionGraph)
 def numba_funcify_FunctionGraph(
     fgraph,
     node=None,
     fgraph_name="numba_funcified_fgraph",
     **kwargs,
 ):
-    # TODO: Create hash key for whole graph
-    return fgraph_to_python(
+    cache_keys = []
+
+    def op_conversion_and_key_collection(*args, **kwargs):
+        func, key = numba_njit_and_cache(*args, **kwargs)
+        cache_keys.append(key)
+        return func
+
+    def type_conversion_and_key_collection(value, variable, **kwargs):
+        if isinstance(variable, Constant):
+            cache_keys.append(key_for_constant(value))
+        return numba_typify(value, variable=variable, **kwargs)
+
+    py_func = fgraph_to_python(
         fgraph,
-        op_conversion_fn=numba_njit_and_cache,
-        type_conversion_fn=numba_typify,
+        op_conversion_fn=op_conversion_and_key_collection,
+        type_conversion_fn=type_conversion_and_key_collection,
         fgraph_name=fgraph_name,
         **kwargs,
     )
+    if any(key is None for key in cache_keys):
+        fgraph_key = None
+    else:
+        fgraph_key = sha256(
+            str((tuple(cache_keys), len(fgraph.inputs), len(fgraph.outputs))).encode()
+        ).hexdigest()
+    return py_func, fgraph_key


 @numba_funcify.register(OpFromGraph)
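
The FunctionGraph dispatcher now collects one key per Op (via numba_njit_and_cache) and one per Constant (via key_for_constant), then folds them into a single SHA-256 digest; any None sub-key makes the whole graph uncacheable. A rough illustration of the key composition, using placeholder per-node keys rather than a real graph, and assuming key_for_constant is imported from the module this commit adds it to:

    from hashlib import sha256
    import numpy as np
    from pytensor.link.numba.dispatch.basic import key_for_constant

    # Constants: plain scalars key to their str(), arrays to a hash of their raw bytes.
    assert key_for_constant(2) == "2"
    arr_key = key_for_constant(np.arange(3))  # sha256(arr.tobytes()).hexdigest()

    # Graph key: the tuple of per-node keys plus the number of inputs/outputs.
    cache_keys = ["<op-key>", arr_key]  # "<op-key>" is a placeholder Op key
    fgraph_key = sha256(
        str((tuple(cache_keys), 1, 1)).encode()  # 1 input, 1 output assumed
    ).hexdigest()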

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@

 from pytensor.graph.op import Op
 from pytensor.link.numba.cache import (
-    compile_and_cache_numba_function_src,
+    compile_numba_function_src,
     numba_funcify_and_cache_key,
     register_funcify_and_cache_key,
     register_funcify_default_op_cache_key,
@@ -239,7 +239,7 @@ def {careduce_fn_name}(x):
     careduce_def_src += "\n\n"
     careduce_def_src += indent(f"return {return_obj}", " " * 4)

-    careduce_fn = compile_and_cache_numba_function_src(
+    careduce_fn = compile_numba_function_src(
         careduce_def_src, careduce_fn_name, {**globals(), **global_env}
     )

pytensor/link/numba/dispatch/scalar.py

Lines changed: 13 additions & 5 deletions
@@ -6,7 +6,8 @@
 from pytensor.compile.ops import TypeCastingOp
 from pytensor.graph.basic import Variable
 from pytensor.link.numba.cache import (
-    compile_and_cache_numba_function_src,
+    compile_numba_function_src,
+    numba_funcify_and_cache_key,
     register_funcify_and_cache_key,
 )
 from pytensor.link.numba.compile import (
@@ -138,7 +139,7 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
     return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
 """

-    scalar_op_fn = compile_and_cache_numba_function_src(
+    scalar_op_fn = compile_numba_function_src(
         scalar_op_src,
         scalar_op_fn_name,
         {**globals(), **global_env},
@@ -172,7 +173,7 @@ def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op:
 def {binary_op_name}({input_signature}):
     return {output_expr}
 """
-    nary_fn = compile_and_cache_numba_function_src(nary_src, binary_op_name, globals())
+    nary_fn = compile_numba_function_src(nary_src, binary_op_name, globals())

     return nary_fn

@@ -234,8 +235,15 @@ def clip(x, min_val, max_val):
 def numba_funcify_Composite(op, node, **kwargs):
     _ = kwargs.pop("storage_map", None)

-    composite_fn = numba_njit(numba_funcify(op.fgraph, squeeze_output=True, **kwargs))
-    return composite_fn, str(tuple(type(node.op) for node in op.fgraph.toposort()))
+    composite_fn, fgraph_key = numba_funcify_and_cache_key(
+        op.fgraph, squeeze_output=True, **kwargs
+    )
+    composite_fn = numba_njit(composite_fn)
+    if fgraph_key is None:
+        composite_key = None
+    else:
+        composite_key = sha256(str((type(op), fgraph_key)).encode()).hexdigest()
+    return composite_fn, composite_key


 @register_funcify_and_cache_key(Second)
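
numba_funcify_Composite now derives its key from the inner FunctionGraph's key (as returned above) rather than from the tuple of inner Op types, so Composites that share the same Op sequence but differ in their inner graph get distinct keys. A small sketch of the derivation, with placeholder values standing in for type(op) and the real fgraph key:

    from hashlib import sha256

    fgraph_key = "deadbeef"  # placeholder for the key returned for op.fgraph
    composite_key = sha256(str(("Composite", fgraph_key)).encode()).hexdigest()
    # A None fgraph_key propagates instead: composite_key = None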

pytensor/link/numba/dispatch/scan.py

Lines changed: 3 additions & 5 deletions
@@ -7,7 +7,7 @@
 from pytensor import In
 from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.mode import NUMBA, get_mode
-from pytensor.link.numba.cache import compile_and_cache_numba_function_src
+from pytensor.link.numba.cache import compile_numba_function_src
 from pytensor.link.numba.compile import (
     create_arg_string,
     create_tuple_string,
@@ -443,12 +443,10 @@ def scan({", ".join(outer_in_names)}):
     }
     global_env["np"] = np

-    scan_op_fn = compile_and_cache_numba_function_src(
+    scan_op_fn = compile_numba_function_src(
         scan_op_src,
         "scan",
         {**globals(), **global_env},
-        # We can't cache until we can hash FunctionGraph
-        cache_key=None,
     )

-    return numba_njit(scan_op_fn, boundscheck=False), None
+    return numba_njit(scan_op_fn, boundscheck=False)

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@

 from pytensor.graph import Type
 from pytensor.link.numba.cache import (
-    compile_and_cache_numba_function_src,
+    compile_numba_function_src,
     register_funcify_default_op_cache_key,
 )
 from pytensor.link.numba.compile import numba_njit
@@ -99,7 +99,7 @@ def {function_name}({", ".join(input_names)}):
     return np.asarray(z)
 """

-    func = compile_and_cache_numba_function_src(
+    func = compile_numba_function_src(
         subtensor_def_src,
         function_name=function_name,
         global_env=globals() | {"np": np},

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 4 additions & 4 deletions
@@ -3,7 +3,7 @@
 import numpy as np

 from pytensor.link.numba.cache import (
-    compile_and_cache_numba_function_src,
+    compile_numba_function_src,
     register_funcify_default_op_cache_key,
 )
 from pytensor.link.numba.compile import (
@@ -56,7 +56,7 @@ def allocempty({", ".join(shape_var_names)}):
     return np.empty(scalar_shape, dtype)
 """

-    alloc_fn = compile_and_cache_numba_function_src(
+    alloc_fn = compile_numba_function_src(
         alloc_def_src, "allocempty", {**globals(), **global_env}
     )

@@ -100,7 +100,7 @@ def alloc(val, {", ".join(shape_var_names)}):
     res[...] = val
     return res
 """
-    alloc_fn = compile_and_cache_numba_function_src(
+    alloc_fn = compile_numba_function_src(
         alloc_def_src,
         "alloc",
         {**globals(), **global_env},
@@ -223,7 +223,7 @@ def makevector({", ".join(input_names)}):
     return np.array({create_list_string(input_names)}, dtype=dtype)
 """

-    makevector_fn = compile_and_cache_numba_function_src(
+    makevector_fn = compile_numba_function_src(
         makevector_def_src,
         "makevector",
         {**globals(), **global_env},

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 from numba.core.types.misc import NoneType
 from numba.np import arrayobj

-from pytensor.link.numba.cache import compile_and_cache_numba_function_src
+from pytensor.link.numba.cache import compile_numba_function_src
 from pytensor.link.numba.compile import numba_njit


@@ -53,7 +53,7 @@ def store_core_outputs({inp_signature}, {out_signature}):
 """
     global_env = {"core_op_fn": core_op_fn}

-    func = compile_and_cache_numba_function_src(
+    func = compile_numba_function_src(
         func_src,
         "store_core_outputs",
         {**globals(), **global_env},

pytensor/link/numba/linker.py

Lines changed: 2 additions & 2 deletions
@@ -9,15 +9,15 @@ def __init__(self, *args, vm: bool = False, **kwargs):
         self.vm = vm

     def fgraph_convert(self, fgraph, **kwargs):
-        from pytensor.link.numba.dispatch import numba_funcify
+        from pytensor.link.numba.compile import numba_funcify

         return numba_funcify(fgraph, **kwargs)

     def jit_compile(self, fn):
         if self.vm:
             return fn
         else:
-            from pytensor.link.numba.compile import numba_njit
+            from pytensor.link.numba.cache import numba_njit

             jitted_fn = numba_njit(fn, final_function=True)
             return jitted_fn

tests/compile/function/test_types.py

Lines changed: 1 addition & 1 deletion
@@ -1466,7 +1466,7 @@ def zerosumnormal(name, *, sigma=1.0, size, model_logp):
     return joined_inputs, [model_logp, model_dlogp]


-@pytest.mark.parametrize("mode", ["NUMBA", "C", "C_VM", "NUMBA"][:1])
+@pytest.mark.parametrize("mode", ["C", "C_VM", "NUMBA"][2:])
 def test_radon_model_compile_benchmark(mode, radon_model, benchmark):
     joined_inputs, [model_logp, model_dlogp] = radon_model
     rng = np.random.default_rng(1)
