Skip to content

Commit 3677f33

Browse files
committed
Cache more Ops
1 parent 9d52726 commit 3677f33

File tree

4 files changed

+37
-17
lines changed

4 files changed

+37
-17
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from numba import types
1515
from numba.core.errors import NumbaWarning, TypingError
1616
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
17-
from numba.extending import box, overload, register_jitable as _register_jitable
17+
from numba.extending import box, overload
18+
from numba.extending import register_jitable as _register_jitable
1819

1920
from pytensor import In, config
2021
from pytensor.compile import NUMBA
@@ -25,11 +26,9 @@
2526
from pytensor.graph.fg import FunctionGraph
2627
from pytensor.graph.type import Type
2728
from pytensor.ifelse import IfElse
29+
from pytensor.link.numba.cache import compile_and_cache_numba_function_src
2830
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
29-
from pytensor.link.utils import (
30-
compile_function_src,
31-
fgraph_to_python,
32-
)
31+
from pytensor.link.utils import fgraph_to_python
3332
from pytensor.scalar.basic import ScalarType
3433
from pytensor.sparse import SparseTensorType
3534
from pytensor.tensor.basic import Nonzero
@@ -40,6 +39,7 @@
4039
from pytensor.tensor.sort import ArgSortOp, SortOp
4140
from pytensor.tensor.type import TensorType
4241
from pytensor.tensor.type_other import MakeSlice, NoneConst
42+
from pytensor.utils import hash_from_code
4343

4444

4545
def global_numba_func(func):
@@ -562,7 +562,12 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
562562
"""
563563
)
564564

565-
specify_shape = compile_function_src(func, "specify_shape", globals())
565+
specify_shape = compile_and_cache_numba_function_src(
566+
func,
567+
"specify_shape",
568+
globals(),
569+
key=hash_from_code(func),
570+
)
566571
return numba_njit(specify_shape)
567572

568573

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import numpy as np
22

33
from pytensor.graph import Type
4+
from pytensor.link.numba.cache import compile_and_cache_numba_function_src
45
from pytensor.link.numba.dispatch import numba_funcify
56
from pytensor.link.numba.dispatch.basic import generate_fallback_impl, numba_njit
6-
from pytensor.link.utils import compile_function_src, unique_name_generator
7+
from pytensor.link.utils import unique_name_generator
78
from pytensor.tensor import TensorType
89
from pytensor.tensor.rewriting.subtensor import is_full_slice
910
from pytensor.tensor.subtensor import (
@@ -15,6 +16,7 @@
1516
Subtensor,
1617
)
1718
from pytensor.tensor.type_other import NoneTypeT, SliceType
19+
from pytensor.utils import hash_from_code
1820

1921

2022
@numba_funcify.register(Subtensor)
@@ -95,10 +97,11 @@ def {function_name}({", ".join(input_names)}):
9597
return np.asarray(z)
9698
"""
9799

98-
func = compile_function_src(
100+
func = compile_and_cache_numba_function_src(
99101
subtensor_def_src,
100102
function_name=function_name,
101103
global_env=globals() | {"np": np},
104+
key=hash_from_code(subtensor_def_src),
102105
)
103106
return numba_njit(func, boundscheck=True)
104107

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import numpy as np
44

5+
from pytensor.link.numba.cache import compile_and_cache_numba_function_src
56
from pytensor.link.numba.dispatch import basic as numba_basic
67
from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify
7-
from pytensor.link.utils import compile_function_src, unique_name_generator
8+
from pytensor.link.utils import unique_name_generator
89
from pytensor.tensor.basic import (
910
Alloc,
1011
AllocEmpty,
@@ -17,6 +18,7 @@
1718
Split,
1819
TensorFromScalar,
1920
)
21+
from pytensor.utils import hash_from_code
2022

2123

2224
@numba_funcify.register(AllocEmpty)
@@ -49,8 +51,11 @@ def allocempty({", ".join(shape_var_names)}):
4951
return np.empty(scalar_shape, dtype)
5052
"""
5153

52-
alloc_fn = compile_function_src(
53-
alloc_def_src, "allocempty", {**globals(), **global_env}
54+
alloc_fn = compile_and_cache_numba_function_src(
55+
alloc_def_src,
56+
"allocempty",
57+
{**globals(), **global_env},
58+
key=hash_from_code(alloc_def_src),
5459
)
5560

5661
return numba_basic.numba_njit(alloc_fn)
@@ -93,7 +98,12 @@ def alloc(val, {", ".join(shape_var_names)}):
9398
res[...] = val
9499
return res
95100
"""
96-
alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
101+
alloc_fn = compile_and_cache_numba_function_src(
102+
alloc_def_src,
103+
"alloc",
104+
{**globals(), **global_env},
105+
key=hash_from_code(alloc_def_src),
106+
)
97107

98108
return numba_basic.numba_njit(alloc_fn)
99109

@@ -212,8 +222,11 @@ def makevector({", ".join(input_names)}):
212222
return np.array({create_list_string(input_names)}, dtype=dtype)
213223
"""
214224

215-
makevector_fn = compile_function_src(
216-
makevector_def_src, "makevector", {**globals(), **global_env}
225+
makevector_fn = compile_and_cache_numba_function_src(
226+
makevector_def_src,
227+
"makevector",
228+
{**globals(), **global_env},
229+
key=f"MakeVector({op.dtype})",
217230
)
218231

219232
return numba_basic.numba_njit(makevector_fn)

pytensor/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,8 @@ def hash_from_code(msg: str | bytes) -> str:
191191
# but Python 3 (unicode) strings don't.
192192
if isinstance(msg, str):
193193
msg = msg.encode()
194-
# Python 3 does not like module names that start with
195-
# a digit.
196-
return "m" + hashlib.sha256(msg).hexdigest()
194+
# Python 3 does not like module names that start with a digit.
195+
return f"m{hashlib.sha256(msg).hexdigest()}"
197196

198197

199198
def uniq(seq: Sequence) -> list:

0 commit comments

Comments (0)