
Conversation

@ricardoV94
Member

@ricardoV94 ricardoV94 commented Oct 7, 2025

Cherry-picking changes from #1604

It seems that simply improving the caching of individual Ops gives us a large speedup while still jitting the whole graph.

THIS is still very dirty, and perhaps overkill. We compile+exec every single Op, even those that don't need our custom cache control. OTOH it's quite hard to know what numba will or won't accept caching for, and as mentioned here, numba cache invalidation also leaves a lot to be desired: #1604 (comment)

The compile+exec is needed to lift variables/functions out of the function closure into the global scope.

Otherwise numba will look into those to check whether the cache is stale (numba always has the last word on whether a cache is stale or not). Depending on how they are serialized, these variables can look different even when they haven't changed and the function is exactly the same as before.

This is purely gaming numba for our purposes.
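A minimal sketch of the trick, with made-up names rather than the actual PR helpers: a value captured in a closure cell forces numba to inspect that closure when validating its on-disk cache, whereas compile+exec lets us move the value into the generated function's globals so the function source stays byte-for-byte identical across runs.

import numba

def naive_impl(op_param):
    # op_param lives in a closure cell; numba serializes it as part of the
    # cache-staleness check, and its serialized form can differ between
    # sessions even when the value is semantically the same.
    @numba.njit
    def impl(x):
        return x * op_param

    return impl

def lifted_impl(op_param):
    # Generate source that only refers to a global name, then exec it with the
    # value injected into that namespace. Paired with a custom cache locator
    # (as this PR does), numba can then reuse the cached machine code.
    src = "def impl(x):\n    return x * OP_PARAM\n"
    global_env = {"OP_PARAM": op_param}
    exec(compile(src, "<pytensor_generated>", "exec"), global_env)
    return numba.njit(global_env["impl"])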

Benchmarking

And here are the timings for compiling the radon model repeatedly:

Before
--------------------------------------------------------------------------------------------------------- benchmark: 3 tests --------------------------------------------------------------------------------------------------------
Name (time in ms)                                             Min                   Max                  Mean              StdDev                Median                 IQR            Outliers     OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_radon_model_repeated_compile_benchmark[C_VM]        641.6810 (1.0)        923.9313 (1.03)       728.9724 (1.01)     118.2841 (1.16)       668.2491 (1.0)      149.5554 (1.27)          1;0  1.3718 (0.99)          5           1
test_radon_model_repeated_compile_benchmark[C]           653.8087 (1.02)       895.6774 (1.0)        723.1066 (1.0)      101.6445 (1.0)        675.7941 (1.01)     117.9712 (1.0)           1;0  1.3829 (1.0)           5           1
test_radon_model_repeated_compile_benchmark[NUMBA]     7,505.5632 (11.70)    8,140.9498 (9.09)     7,836.1051 (10.84)    241.2376 (2.37)     7,894.4678 (11.81)    332.0119 (2.81)          2;0  0.1276 (0.09)          5           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After
------------------------------------------------------------------------------------------------------ benchmark: 3 tests -----------------------------------------------------------------------------------------------------
Name (time in ms)                                           Min                   Max                Mean              StdDev              Median                 IQR            Outliers     OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_radon_model_repeated_compile_benchmark[C]         676.8789 (1.0)      1,028.5802 (1.05)     786.7192 (1.0)      140.3314 (1.41)     748.9989 (1.0)      139.2424 (1.25)          1;0  1.2711 (1.0)           5           1
test_radon_model_repeated_compile_benchmark[C_VM]      740.0859 (1.09)       980.4151 (1.0)      811.9622 (1.03)      99.4709 (1.0)      759.7109 (1.01)     111.1999 (1.0)           1;0  1.2316 (0.97)          5           1
test_radon_model_repeated_compile_benchmark[NUMBA]     762.9275 (1.13)     1,027.3377 (1.05)     888.7869 (1.13)     102.8226 (1.03)     900.3118 (1.20)     153.7402 (1.38)          2;0  1.1251 (0.89)          5           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

And here are the timings for compiling slight variations of the same model (after clearing the cache for both backends):

Before
------------------------------------------------------------------------
Name (time in s)                                           Runtime          
------------------------------------------------------------------------
test_radon_model_compile_variants_benchmark[C_VM]      14.6317 (1.0)    
test_radon_model_compile_variants_benchmark[C]         37.0302 (2.53)   
test_radon_model_compile_variants_benchmark[NUMBA]     51.8231 (3.54)   
------------------------------------------------------------------------

After
-----------------------------------------------------------------------
Name (time in s)                                           Runtime         
-----------------------------------------------------------------------
test_radon_model_compile_variants_benchmark[C_VM]      16.4719 (1.0)   
test_radon_model_compile_variants_benchmark[NUMBA]     26.0248 (1.58)  
test_radon_model_compile_variants_benchmark[C]         37.1087 (2.25)  
-----------------------------------------------------------------------

And this comes at no cost in evaluation runtime, unlike the VM approach, which is 2x slower in the C backend and many times slower in the naive implementation in #1604 (the numba overhead per individually jitted function added up there):

-------------------------------------------------------------------------------------------------- benchmark: 4 tests --------------------------------------------------------------------------------------------------
Name (time in us)                                  Min                   Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_radon_model_call_benchmark[NUMBA]         21.3900 (1.0)         35.4870 (1.0)      25.0974 (1.0)       5.9128 (1.88)     22.2920 (1.0)       5.2898 (13.20)         1;1       39.8448 (1.0)           5           1
test_radon_model_call_benchmark[C]             22.6220 (1.06)        72.1550 (2.03)     27.6587 (1.10)      3.1417 (1.0)      27.8320 (1.25)      0.4007 (1.0)     2577;3297       36.1550 (0.91)       9691           1
test_radon_model_call_benchmark[C_VM_NOGC]     33.3120 (1.56)     1,177.2980 (33.18)    46.5852 (1.86)     16.1183 (5.13)     40.5760 (1.82)     12.9245 (32.25)     790;286       21.4661 (0.54)       8988           1
test_radon_model_call_benchmark[C_VM]          44.9940 (2.10)       633.7580 (17.86)    52.8585 (2.11)     11.7635 (3.74)     52.3990 (2.35)      6.5520 (16.35)     570;652       18.9184 (0.47)       8000           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Conclusion so far

  • If you're recompiling the very same function, numba is now as fast as the other backends (and 10x faster than it was before)
  • If you're compiling slightly different functions, numba is 2x slower than the default C_VM backend, but faster than the "equivalent" fully compiled C. Compared to itself, it is 2x faster than before, so in general numba will compile anywhere between 2x and 10x faster than before for similar kinds of models.

TODO:

  • Hash constants correctly
  • Investigate whether doing sha256(sha256(key1), sha256(key2)) has much higher collision chances than sha256(key1, key2) (see the sketch after this list)
  • Split numba benchmark tests and add a numba no-cache mode to the test
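For reference, here is a self-contained illustration of the two composition schemes from the TODO above (plain hashlib, not the PR's actual key-building code). Beyond collision probability, note that concatenating raw keys without a delimiter makes ("ab", "c") and ("a", "bc") collide trivially, whereas hashing each part first gives every part a fixed length and unambiguous boundaries:

import hashlib

def nested_key(key1: bytes, key2: bytes) -> str:
    # sha256(sha256(key1) || sha256(key2)): each part is reduced to a fixed
    # 32-byte digest before the outer hash, so part boundaries are unambiguous.
    h1 = hashlib.sha256(key1).digest()
    h2 = hashlib.sha256(key2).digest()
    return hashlib.sha256(h1 + h2).hexdigest()

def flat_key(key1: bytes, key2: bytes) -> str:
    # sha256(key1 || key2): hashes the raw concatenation, so different splits
    # of the same byte stream produce the same key.
    return hashlib.sha256(key1 + key2).hexdigest()

assert flat_key(b"ab", b"c") == flat_key(b"a", b"bc")      # boundary collision
assert nested_key(b"ab", b"c") != nested_key(b"a", b"bc")  # boundaries preserved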

Follow-up PRs

@twiecki
Member

twiecki commented Oct 8, 2025


the suspense is killing me

@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch from 943f18d to 885aeea on October 8, 2025 20:24
@ricardoV94
Member Author

Updated: it's 2.5x faster than before.

@twiecki
Member

twiecki commented Oct 9, 2025

Is this Good Enough(TM)?

@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch 2 times, most recently from 3468613 to 0969905 on October 9, 2025 11:57
@ricardoV94
Member Author

ricardoV94 commented Oct 9, 2025

Is this Good Enough(TM)?

I've managed to cache the whole function graph (still in hack mode), which means that if you recompile the exact same graph it's as fast to compile as with the other backends (and it executes faster, so that's nice). This is not so rare when working in Jupyter notebooks and the like.

For compiling slightly different versions (new test) it's 2x faster than before and no longer the slowest backend. The closest is the C (as opposed to C VM), which also compiles the whole graph.

It's still 2x slower than the C VM. This is as expected; it's why the Theano guys moved away from the C backend as the default. However, that may actually be fine for now? We can tell users to switch to the C VM backend if the compile times are prohibitive?

@ricardoV94 ricardoV94 changed the title from "Cache numba cache" to "Trust me numba: you can use this cache" on Oct 9, 2025
@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch 2 times, most recently from f3851a7 to 690a637 on October 9, 2025 16:03
@twiecki
Member

twiecki commented Oct 10, 2025

Neat! Yeah, sounds like this is good enough to switch the default backend to numba.

@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch 7 times, most recently from 2a297b8 to e4d64a1 on October 23, 2025 21:17
@ricardoV94
Member Author

ricardoV94 commented Oct 23, 2025

Okay, this is getting there. Missing some docstrings and minor cleanup. The big pain point is the blas/lapack functions, which we'll have to handle manually as well, but I would leave those to a separate PR.

I can re-run the whole numba CI in 5 minutes now, of which the uncacheable blas/lapack stuff takes 3-4 minutes.

This will also unblock #1445, which was blocked by our clumsy/eager attempt to cache everything, even the uncacheable. The new system allows a sub-function to cleanly inform higher-order functions that it is uncacheable, and therefore that the functions using it are uncacheable as well.
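To illustrate the propagation idea (the names here are hypothetical, not the PR's API): each funcify-style helper returns the compiled callable together with a cache key, or None when caching is impossible, and anything built on top of a None key reports None as well.

import hashlib

def combine_keys(*parts: str) -> str:
    # Toy key combiner: hash each textual part, then hash the digests together.
    outer = hashlib.sha256()
    for part in parts:
        outer.update(hashlib.sha256(part.encode()).digest())
    return outer.hexdigest()

def wrap_higher_order(inner_fn, inner_key, op_name: str):
    # A higher-order Op wrapping inner_fn is cacheable only if the inner
    # function is; otherwise the None cache key propagates upward.
    def outer_fn(*args):
        return inner_fn(*args)

    if inner_key is None:
        return outer_fn, None
    return outer_fn, combine_keys(op_name, inner_key)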

@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch from e4d64a1 to ac2605b on October 23, 2025 21:29
@codecov

codecov bot commented Oct 23, 2025

Codecov Report

❌ Patch coverage is 91.06830% with 51 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.75%. Comparing base (60ba7c7) to head (9158fce).
⚠️ Report is 4 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/numba/dispatch/basic.py 86.55% 12 Missing and 4 partials ⚠️
pytensor/link/numba/cache.py 85.18% 4 Missing and 4 partials ⚠️
pytensor/bin/pytensor_cache.py 16.66% 5 Missing ⚠️
pytensor/link/numba/dispatch/compile_ops.py 92.64% 3 Missing and 2 partials ⚠️
pytensor/link/numba/dispatch/random.py 66.66% 5 Missing ⚠️
pytensor/link/numba/dispatch/blockwise.py 77.77% 4 Missing ⚠️
pytensor/link/numba/dispatch/elemwise.py 96.49% 2 Missing ⚠️
pytensor/link/numba/dispatch/scan.py 77.77% 1 Missing and 1 partial ⚠️
pytensor/link/numba/dispatch/vectorize_codegen.py 75.00% 1 Missing and 1 partial ⚠️
pytensor/configdefaults.py 66.66% 1 Missing ⚠️
... and 1 more

❌ Your patch check has failed because the patch coverage (91.06%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1637      +/-   ##
==========================================
+ Coverage   81.67%   81.75%   +0.08%     
==========================================
  Files         244      248       +4     
  Lines       53558    53822     +264     
  Branches     9433     9459      +26     
==========================================
+ Hits        43741    44002     +261     
+ Misses       7337     7333       -4     
- Partials     2480     2487       +7     
Files with missing lines Coverage Δ
pytensor/compile/mode.py 85.13% <100.00%> (+0.13%) ⬆️
pytensor/configparser.py 92.60% <100.00%> (+0.04%) ⬆️
pytensor/link/numba/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/extra_ops.py 96.92% <100.00%> (+0.82%) ⬆️
...sor/link/numba/dispatch/linalg/decomposition/lu.py 66.66% <100.00%> (ø)
...or/link/numba/dispatch/linalg/solve/tridiagonal.py 55.39% <100.00%> (ø)
pytensor/link/numba/dispatch/nlinalg.py 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/shape.py 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/signal/conv.py 32.69% <100.00%> (ø)
pytensor/link/numba/dispatch/slinalg.py 68.54% <100.00%> (+0.14%) ⬆️
... and 16 more

... and 7 files with indirect coverage changes


@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch 3 times, most recently from ebaf3f7 to f32dcbb on October 25, 2025 11:42
@ricardoV94 ricardoV94 marked this pull request as ready for review October 25, 2025 11:47
@ricardoV94
Member Author

Just wrestling with mypy at this point; everything is ready otherwise.

@ricardoV94
Member Author

ricardoV94 commented Oct 27, 2025

Do you know of any places where we need it?

We don't need it. This is just for code coverage and, supposedly, for catching trivial Python errors in Python mode instead of the unreadable numba tracebacks when developing.

@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch from efaf2c7 to 00688f9 on October 27, 2025 15:25
@ricardoV94
Member Author

Tests are passing and all immediate TODOs have been addressed.

@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch from 00688f9 to ebb2f00 on October 28, 2025 09:00
@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch 2 times, most recently from 2d6b776 to c0249f6 on October 28, 2025 09:54
These tests were covering things that don't exist anymore: params in the Python perform method of Ops, or misbehavior of an Op not respecting the signature
Direct import is not properly mocked by tests when trying to run `compare_numba_and_py` with `eval_obj_mode=True`
@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch 3 times, most recently from 16ea993 to eca6cf1 on October 28, 2025 12:26
Member

@aseyboldt aseyboldt left a comment

This looks great, I'll be super glad if we can reduce the numba compile time!

Some of the cache key generation feels to me like it should just happen in Op or subclasses of that directly, instead of in the numba backend. Sometimes we might want to combine Ops with different identities into one numba function, but can't we usually just ask the Op itself to give us a cache key, and move some of the new code there?

@numba_njit
def opfromgraph(*inputs):
    return fgraph_fn(*inputs)
The default cache key is based on the string representations of: `type(op)` and the
Member

Maybe we should also add the pytensor version to this?
If we don't, I think we have to always increase the counter for any implementation change from now on, and I'm a bit scared of the kind of bugs that this could lead to if we ever forget or mess it up somehow (only reproducible if you run the old version of pytensor first...).
This would slow down compilation after every pytensor update, but I think that might be worth the cost.
This would leave the counter for development.

Member Author

We do it like this for C and have never had big issues. Also, we are pretty liberal with releasing PyTensor these days, and moreover we have the automatic release name, which for dev builds changes with every commit, so it would be a lot of waste for devs. In any case, there is a designated place here in the cache locator that can be used to invalidate all previous caches at once:
[screenshot of the cache locator]

print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201
return jitable_func, None
else:
op_name = op.__class__.__name__
Member

Is this name guaranteed to be unique? Or could we end up with a function from the same Op in the globals?

Member Author

Not sure it matters; I assume compiling won't override the existing globals, and being part of the globals of the compiled function means it will simply not be used? I don't think we were worried about creating unique names when compiling functions in the dispatches (scan, store_core_outputs, opfromgraph). I'll double-check.

fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))

if len(op.fgraph.outputs) == 1:
def register_funcify_default_op_cache_key(op_type):
Member

Can we somehow keep the old API? I think this breaks all user-defined numba ops, and I don't see a good reason to do so. Couldn't we just make the old numba_funcify.register call this function?

Member Author

This is backward compatible. The functions we now call to compile fall back to the old numba_funcify; it just means those graphs won't be cached.

Furthermore, the new decorators also register into the old numba_funcify, just stripping away the cache key, in case users/libraries were calling those directly.
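Very roughly, the dual registration looks like this (a sketch with made-up dispatcher and decorator names, not the PR's exact code): the new-style registration returns (function, cache_key) and also installs a thin wrapper on the old-style singledispatch, so existing callers keep working, just without a cache key.

from functools import singledispatch

@singledispatch
def funcify_old_style(op, **kwargs):
    # Old-style dispatcher: returns just a callable.
    raise NotImplementedError(type(op))

@singledispatch
def funcify_with_cache_key(op, **kwargs):
    # New-style dispatcher: returns (callable, cache_key_or_None).
    raise NotImplementedError(type(op))

def register_with_cache_key(op_type):
    def decorator(new_style_fn):
        funcify_with_cache_key.register(op_type)(new_style_fn)

        @funcify_old_style.register(op_type)
        def _old_style(op, **kwargs):
            fn, _cache_key = new_style_fn(op, **kwargs)
            return fn

        return new_style_fn

    return decorator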

Member Author

[screenshot]

Member Author

I added new tests that check that the multiple APIs are cross-compatible.

@ricardoV94
Member Author

ricardoV94 commented Nov 3, 2025

but can't we usually just ask the Op itself to give us a cache key, and move some of the new code there?

I think this is backend-specific. The same Op may require distinct functions in one backend (as numba does with Subtensor) but a single one in another (in jax it's just x[indices]). But I may be missing what you had in mind specifically?

@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch from eca6cf1 to 06f03b4 on November 3, 2025 13:03
@ricardoV94 ricardoV94 force-pushed the cherry_pick_numba_cache branch from 06f03b4 to 9158fce on November 3, 2025 13:19