diff --git a/jpviz/__init__.py b/jpviz/__init__.py index 696938a..c50a765 100644 --- a/jpviz/__init__.py +++ b/jpviz/__init__.py @@ -6,7 +6,7 @@ from .dot import draw_dot_graph -def draw(f, collapse_primitives=True, show_avals=True) -> typing.Callable: +def draw(f, collapse_primitives=True, show_avals=True, static_argnums=()) -> typing.Callable: """ Visualise a JAX computation graph @@ -45,6 +45,11 @@ def bar(x): show_avals: bool If `True` then type information will be included on node labels + static_argnums: static_argnums: int | Iterable[int] = () + Optional sequence of argument indices to treat + as static (compile time constant) rather than + dynamic (runtime value). See `jax.jit` documentation + for details Returns ------- @@ -54,7 +59,7 @@ def bar(x): """ def _inner_draw(*args, **kwargs) -> pydot.Graph: - jaxpr = jax.make_jaxpr(f)(*args, **kwargs) + jaxpr = jax.make_jaxpr(f, static_argnums=static_argnums)(*args, **kwargs) return draw_dot_graph(jaxpr, collapse_primitives, show_avals) return _inner_draw diff --git a/pyproject.toml b/pyproject.toml index ef700ca..49bd0b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,10 +23,10 @@ classifiers = [ ] [tool.poetry.dependencies] -python = "^3.10,<3.13" -jax = "^0.4.14" -pydot = "^1.4.2" -scipy = "^1.10.0" +python = ">=3.9,<3.13" +jax = ">=0.4.0" +pydot = "*" +scipy = "*" [tool.poetry.group.dev.dependencies] pre-commit = "^3.3.3"