Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions jpviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .dot import draw_dot_graph


def draw(f, collapse_primitives=True, show_avals=True) -> typing.Callable:
def draw(f, collapse_primitives=True, show_avals=True, static_argnums=()) -> typing.Callable:
"""
Visualise a JAX computation graph

Expand Down Expand Up @@ -45,6 +45,11 @@ def bar(x):
show_avals: bool
If `True` then type information will be
included on node labels
static_argnums: static_argnums: int | Iterable[int] = ()
Optional sequence of argument indices to treat
as static (compile time constant) rather than
dynamic (runtime value). See `jax.jit` documentation
for details

Returns
-------
Expand All @@ -54,7 +59,7 @@ def bar(x):
"""

def _inner_draw(*args, **kwargs) -> pydot.Graph:
jaxpr = jax.make_jaxpr(f)(*args, **kwargs)
jaxpr = jax.make_jaxpr(f, static_argnums=static_argnums)(*args, **kwargs)
return draw_dot_graph(jaxpr, collapse_primitives, show_avals)

return _inner_draw
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ classifiers = [
]

[tool.poetry.dependencies]
python = "^3.10,<3.13"
jax = "^0.4.14"
pydot = "^1.4.2"
scipy = "^1.10.0"
python = ">=3.9,<3.13"
jax = ">=0.4.0"
pydot = "*"
scipy = "*"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.3.3"
Expand Down