Reliable caching of Graphs and individual Ops in numba backend #1637
Conversation
Updated: it's 2.5x faster than before.
Is this Good Enough(TM)?
I've managed to cache the whole function graph (still in hack mode), which means that if you recompile the exact same graph, it's as fast to compile as the other backends (and executes faster, so that's nice). This is not so rare when working in Jupyter notebooks and the like. For compiling slightly different versions (a new test), it's 2x faster than before and no longer the slowest backend. The closest is the C backend (as opposed to the C VM), which also compiles the whole graph. It's still 2x slower than the C VM. This is as expected; that's why the Theano guys moved away from the C backend as the default. However, that may actually be fine for now? We can tell users to switch to the C VM backend if the compile times are prohibitive?
Neat! Yeah, sounds like this is good enough to switch the default backend to numba.
Okay, this is getting there. Missing some docstrings and minor cleanup. The big pain point is the blas/lapack functions, which we'll have to handle manually as well, but I would leave those to a separate PR. I can re-run the whole numba CI in 5 minutes now, of which the uncacheable blas/lapack stuff takes 3-4 minutes. This will also unblock #1445, which was blocked by our clumsy/eager attempt to cache everything, even the uncacheable. The new system allows us to cleanly inform higher-order functions when a sub-function is uncacheable, and therefore that the functions using it are uncacheable as well.
Codecov Report

❌ The patch check has failed because the patch coverage (91.06%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:
@@ Coverage Diff @@
## main #1637 +/- ##
==========================================
+ Coverage 81.67% 81.75% +0.08%
==========================================
Files 244 248 +4
Lines 53558 53822 +264
Branches 9433 9459 +26
==========================================
+ Hits 43741 44002 +261
+ Misses 7337 7333 -4
- Partials 2480 2487 +7
Just wrestling with mypy at this point; everything else is ready.
We don't need it. This is just for code coverage and, supposedly, to find trivial Python errors in Python mode instead of the unreadable numba tracebacks when developing.
Tests are passing and all immediate TODOs have been addressed.
These tests were covering things that don't exist anymore: `params` in the Python `perform` method of Ops, or misbehavior of an Op not respecting the signature.
Direct import is not properly mocked by tests when trying to run `compare_numba_and_py` with `eval_obj_mode=True`
This looks great, I'll be super glad if we can reduce the numba compile time!
Some of the cache key generation feels to me like it should just happen in Op or its subclasses directly, instead of in the numba backend. Sometimes we might want to combine Ops with different identities into one numba function, but can't we usually just ask the Op itself to give us a cache key, and move some of the new code there?
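For what it's worth, a minimal sketch of what that could look like, assuming a hypothetical `numba_cache_key` method on the Op and a backend-side fallback helper (neither is this PR's actual API):

```python
# Hypothetical sketch: the Op describes its own cache identity, so a backend
# only needs a generic fallback. Names are illustrative, not this PR's API.
class MyOp:
    __props__ = ("dtype", "axis")

    def __init__(self, dtype, axis):
        self.dtype = dtype
        self.axis = axis

    def numba_cache_key(self):
        # Build the key from the Op's class and its declared props.
        props = ",".join(f"{p}={getattr(self, p)!r}" for p in self.__props__)
        return f"{type(self).__module__}.{type(self).__qualname__}({props})"


def cache_key_for(op):
    # Backend-side helper: ask the Op first; None means "don't cache".
    key_fn = getattr(op, "numba_cache_key", None)
    return key_fn() if key_fn is not None else None


print(cache_key_for(MyOp("float64", 0)))
# -> __main__.MyOp(dtype='float64',axis=0)
```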
@numba_njit
def opfromgraph(*inputs):
    return fgraph_fn(*inputs)

The default cache key is based on the string representations of: `type(op)` and the
Maybe we should also add the pytensor version to this?
If we don't, I think we have to always increase the counter for any implementation change from now on, and I'm a bit scared of the kind of bugs that this could lead to if we ever forget or mess it up somehow. (Only reproducible if you run the old version of pytensor first...)
This would slow down compilation after every pytensor update, but I think that might be worth the cost.
This would leave the counter for development.
We do it like this for C and never had big issues. Also, we are pretty liberal with releasing PyTensor these days, and moreover we have the automatic release name, which for dev builds changes with every commit, so it would be a lot of wasted recompilation for devs. In any case, there is a designated place here in the cache locator that can be used to invalidate all previous caches at once:

    print(f"{op} of type {type(op)} will not be cached by PyTensor.\n")  # noqa: T201
    return jitable_func, None
else:
    op_name = op.__class__.__name__
Is this name guaranteed to be unique? Or could we end up with a function from the same Op in the globals?
Not sure it matters. I assume compiling won't overwrite the existing globals, and, being part of the globals of the compiled function, it would simply not be used? I don't think we were worried about creating unique names when compiling functions in the dispatches (scan, store_core_outputs, opfromgraph). I'll double-check.
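A tiny illustration of that reasoning, under the assumption that each generated function is exec'd into its own namespace dict (the source string and names below are made up for the example):

```python
# If every generated function is exec'd into its own namespace dict, two
# functions derived from Ops with the same class name don't clobber each
# other: each one only sees the globals it was compiled with.
src = "def add_op(x):\n    return x + CONST\n"

namespace_a = {"CONST": 1}
namespace_b = {"CONST": 100}
exec(compile(src, "<op a>", "exec"), namespace_a)
exec(compile(src, "<op b>", "exec"), namespace_b)

assert namespace_a["add_op"](1) == 2
assert namespace_b["add_op"](1) == 101
```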
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))

if len(op.fgraph.outputs) == 1:

def register_funcify_default_op_cache_key(op_type):
Can we somehow keep the old API? I think this breaks all user-defined numba Ops, and I don't see a good reason to do so. Couldn't we just make the old `numba_funcify.register` call this function?
This is backward compatible. The functions we call to compile now fall back to the old `numba_funcify`; it just means those graphs won't be cached.
Furthermore, the new decorators also register into the old `numba_funcify`, just stripping away the cache key, in case users/libraries were calling those directly.
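A hedged sketch of that double registration; `register_funcify_and_cache_key` and `CACHED_FUNCIFY_REGISTRY` are illustrative stand-ins rather than the PR's actual identifiers:

```python
from functools import singledispatch, wraps


@singledispatch
def numba_funcify(op, **kwargs):  # stand-in for the old entry point
    raise NotImplementedError(f"No numba implementation for {type(op)}")


# Registry used by the new cache-aware compilation path.
CACHED_FUNCIFY_REGISTRY = {}


def register_funcify_and_cache_key(op_type):
    def decorator(func):
        # New API: func(op, **kwargs) -> (jitable_func, cache_key)
        CACHED_FUNCIFY_REGISTRY[op_type] = func

        # Also register into the old single-dispatch entry point, stripping
        # the cache key, so callers of the old API keep working (uncached).
        @numba_funcify.register(op_type)
        @wraps(func)
        def old_api(op, **kwargs):
            jitable_func, _cache_key = func(op, **kwargs)
            return jitable_func

        return func

    return decorator
```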
I added new tests that cover the multiple APIs being cross-compatible
I think this is backend-specific. The same Op may require distinct functions in one backend (like numba does with Subtensor) but a single one in another (jax is just …
This way they can be fully cached when re-running tests
Partially reverts d894350
Cherry picking changes from #1604
It seems that simply improving the caching of individual Ops gives us a lot of speedup, even while still jitting the whole graph.
THIS is still very dirty, and perhaps overkill. We compile+exec every single Op, even if it doesn't need our custom cache control. OTOH, it's quite hard to know what numba will or won't accept caching for, and, as mentioned here, numba cache invalidation also leaves a lot to be desired: #1604 (comment)
The compile+exec is needed to lift variables/functions out of the function closure into the global scope.
Otherwise numba will look into those to check whether the cache is stale (numba always has the last word on whether a cache is stale or not). Depending on how they are serialized, these variables can look different even if they haven't changed and the function is exactly the same as before.
This is purely gaming numba for our purposes.
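A minimal illustration of that trick outside of pytensor (the helper names and generated source below are made up for the example): the closure version keeps `helper` in a closure cell that a cache-staleness check could inspect, while the exec'd version only reaches it through its globals.

```python
def make_closure_version(helper):
    # `helper` ends up in a closure cell of `wrapped`.
    def wrapped(x):
        return helper(x) + 1

    return wrapped


def make_exec_version(helper):
    # Generate source and exec it with `helper` placed in the new function's
    # globals, so the resulting function has no closure cells at all.
    src = "def wrapped(x):\n    return helper(x) + 1\n"
    namespace = {"helper": helper}
    exec(compile(src, "<generated>", "exec"), namespace)
    return namespace["wrapped"]


def double(x):
    return 2 * x


f_closure = make_closure_version(double)
f_global = make_exec_version(double)
assert f_closure.__closure__ is not None  # closure cell present
assert f_global.__closure__ is None       # helper now lives in __globals__
assert f_closure(3) == f_global(3) == 7
```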
Benchmarking
And here are the timings for compiling the radon model repeatedly:
And compiling slight variations of the same model (after clearing the cache for both backends)
And this comes at no cost in evaluation runtime, unlike the VM approach, which is 2x slower in the C backend and many times slower in the naive implementation in #1604 (the numba overhead per individual jitted function added up there).
Conclusion so far
Numba compilation is still slower than the C_VM, but faster than the "equivalent" fully compiled C. Compared to itself, it is 2x faster than before, so in general numba will compile anywhere between 2x and 10x faster than before, for similar kinds of models.

TODO:
- `sha256(sha256(key1), sha256(key2))` has much higher collision chances than `sha256(key1, key2)`

Follow Up PRs
- `cache=True` failures with locally defined functions: numba/numba#10098 (comment)