Fix dependency pins and enable static arguments #27
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
I was able to generate a nice computation graph from this. I had to fix a few things to make it play well with my setup. First is that the pyproject dependencies are too strict. When using uv to resolve this, it forces downgrades on my local version of jax which is 2 major versions behind at this point. My package also happens to be 3.9 compatible, and I think its good practice to use as low as a python version you can without going out on a limb. I didn't touch any of the poetry stuff, probably worth a bump if you do any maintenance in the future (this can be consolidated into the pyproject, i think pep 621 and 735 have been poetry/uv stuff easier to put into one place).
I added the static_argnums to this, so you can specify the positions where the arguments are static. The jit api also allows for static_argnames, but this requires introspecting the function, so I'm just exposing whatever was in the docs for the make_jaxpr function.
This should fix #26.