Skip to content

Commit 8709393

Browse files
committed
Fix non-vm NUMBA
1 parent 6b5d199 commit 8709393

File tree

5 files changed

+10
-110
lines changed

5 files changed

+10
-110
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,15 @@ def numba_funcify_FunctionGraph(
270270
jit_nodes: bool = False,
271271
**kwargs,
272272
):
273+
def numba_funcify_wrapper(*args, **kwargs):
274+
result = numba_funcify(*args, **kwargs)
275+
if isinstance(result, tuple):
276+
return result[0]
277+
return result
278+
273279
return fgraph_to_python(
274280
fgraph,
275-
op_conversion_fn=numba_funcify_njit if jit_nodes else numba_funcify,
281+
op_conversion_fn=numba_funcify_njit if jit_nodes else numba_funcify_wrapper,
276282
type_conversion_fn=numba_typify,
277283
fgraph_name=fgraph_name,
278284
**kwargs,

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
Subtensor,
1717
)
1818
from pytensor.tensor.type_other import NoneTypeT, SliceType
19-
from pytensor.utils import hash_from_code
2019

2120

2221
@numba_funcify.register(Subtensor)
@@ -102,7 +101,7 @@ def {function_name}({", ".join(input_names)}):
102101
function_name=function_name,
103102
global_env=globals() | {"np": np},
104103
)
105-
return numba_njit(func, boundscheck=True), hash_from_code(subtensor_def_src)
104+
return numba_njit(func, boundscheck=True)
106105

107106

108107
@numba_funcify.register(AdvancedSubtensor)

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,7 @@ def alloc(val, {", ".join(shape_var_names)}):
110110
{**globals(), **global_env},
111111
)
112112

113-
return (
114-
pytensor.link.numba.compile.numba_njit(alloc_fn),
115-
hash_from_code(alloc_def_src),
116-
)
113+
return pytensor.link.numba.compile.numba_njit(alloc_fn)
117114

118115

119116
@numba_funcify.register(ARange)

pytensor/link/numba/linker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def jit_compile(self, fn):
1919
else:
2020
from pytensor.link.numba.compile import numba_njit
2121

22-
jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
22+
jitted_fn = numba_njit(fn, final_function=True)
2323
return jitted_fn
2424

2525
def create_thunk_inputs(self, storage_map):

tests/link/numba/test_compile.py

Lines changed: 0 additions & 102 deletions
This file was deleted.

0 commit comments

Comments
 (0)