|
3 | 3 | import warnings |
4 | 4 | from collections.abc import Callable |
5 | 5 | from functools import singledispatch |
| 6 | +from hashlib import sha256 |
| 7 | +from pickle import dumps |
6 | 8 |
|
7 | 9 | import numba |
8 | 10 | import numpy as np |
|
17 | 19 | from pytensor.compile.builders import OpFromGraph |
18 | 20 | from pytensor.compile.function.types import add_supervisor_to_fgraph |
19 | 21 | from pytensor.compile.ops import DeepCopyOp |
| 22 | +from pytensor.graph import Constant |
20 | 23 | from pytensor.graph.fg import FunctionGraph |
21 | 24 | from pytensor.graph.op import Op |
22 | 25 | from pytensor.ifelse import IfElse |
23 | 26 | from pytensor.link.numba.cache import ( |
24 | 27 | numba_njit_and_cache, |
| 28 | + register_funcify_and_cache_key, |
25 | 29 | register_funcify_default_op_cache_key, |
26 | 30 | ) |
27 | 31 | from pytensor.link.numba.compile import ( |
@@ -226,21 +230,52 @@ def numba_funcify_fallback( |
226 | 230 | return generate_fallback_impl(op, node, storage_map, **kwargs) |
227 | 231 |
|
228 | 232 |
|
229 | | -@numba_funcify.register(FunctionGraph) |
| 233 | +def key_for_constant(data): |
| 234 | + """Create a cache key for a constant value.""" |
| 235 | + # TODO: This is just a placeholder |
| 236 | + if isinstance(data, (int | float | bool | type(None))): |
| 237 | + return str(data) |
| 238 | + try: |
| 239 | + # For NumPy arrays |
| 240 | + return sha256(data.tobytes()).hexdigest() |
| 241 | + except AttributeError: |
| 242 | + # Fallback for other types |
| 243 | + return sha256(dumps(data)).hexdigest() |
| 244 | + |
| 245 | + |
| 246 | +@register_funcify_and_cache_key(FunctionGraph) |
230 | 247 | def numba_funcify_FunctionGraph( |
231 | 248 | fgraph, |
232 | 249 | node=None, |
233 | 250 | fgraph_name="numba_funcified_fgraph", |
234 | 251 | **kwargs, |
235 | 252 | ): |
236 | | - # TODO: Create hash key for whole graph |
237 | | - return fgraph_to_python( |
| 253 | + cache_keys = [] |
| 254 | + |
| 255 | + def op_conversion_and_key_collection(*args, **kwargs): |
| 256 | + func, key = numba_njit_and_cache(*args, **kwargs) |
| 257 | + cache_keys.append(key) |
| 258 | + return func |
| 259 | + |
| 260 | + def type_conversion_and_key_collection(value, variable, **kwargs): |
| 261 | + if isinstance(variable, Constant): |
| 262 | + cache_keys.append(key_for_constant(value)) |
| 263 | + return numba_typify(value, variable=variable, **kwargs) |
| 264 | + |
| 265 | + py_func = fgraph_to_python( |
238 | 266 | fgraph, |
239 | | - op_conversion_fn=numba_njit_and_cache, |
240 | | - type_conversion_fn=numba_typify, |
| 267 | + op_conversion_fn=op_conversion_and_key_collection, |
| 268 | + type_conversion_fn=type_conversion_and_key_collection, |
241 | 269 | fgraph_name=fgraph_name, |
242 | 270 | **kwargs, |
243 | 271 | ) |
| 272 | + if any(key is None for key in cache_keys): |
| 273 | + fgraph_key = None |
| 274 | + else: |
| 275 | + fgraph_key = sha256( |
| 276 | + str((tuple(cache_keys), len(fgraph.inputs), len(fgraph.outputs))).encode() |
| 277 | + ).hexdigest() |
| 278 | + return py_func, fgraph_key |
244 | 279 |
|
245 | 280 |
|
246 | 281 | @numba_funcify.register(OpFromGraph) |
|
0 commit comments