Skip to content

Conversation

acmiyaguchi
Copy link

I was able to generate a nice computation graph from this. I had to fix a few things to make it play well with my setup. First is that the pyproject dependencies are too strict. When using uv to resolve this, it forces downgrades on my local version of jax which is 2 major versions behind at this point. My package also happens to be 3.9 compatible, and I think its good practice to use as low as a python version you can without going out on a limb. I didn't touch any of the poetry stuff, probably worth a bump if you do any maintenance in the future (this can be consolidated into the pyproject, i think pep 621 and 735 have been poetry/uv stuff easier to put into one place).

I added the static_argnums to this, so you can specify the positions where the arguments are static. The jit api also allows for static_argnames, but this requires introspecting the function, so I'm just exposing whatever was in the docs for the make_jaxpr function.

This should fix #26.

acmiyaguchi and others added 3 commits October 10, 2025 17:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

No way of rendering the graph when a function has static arguments

1 participant