diff --git a/grad_dft/utils/chunk.py b/grad_dft/utils/chunk.py index 1298754..f8060fc 100644 --- a/grad_dft/utils/chunk.py +++ b/grad_dft/utils/chunk.py @@ -19,7 +19,7 @@ from jax import numpy as jnp from jax.tree_util import tree_leaves, tree_map -from jax import linear_util as lu +from jax.extend import linear_util as lu from jax.api_util import argnums_partial from .types import Array diff --git a/requirements.txt b/requirements.txt index a236acd..62d8469 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ typeguard==2.13.3 typing_extensions>=4.8.0 jaxtyping pytest>=7.4.3 - +numpy<2.0.0 +python>=3.9