diff --git a/keras/src/quantizers/quantizers_test.py b/keras/src/quantizers/quantizers_test.py index 1f0e8217778..d5671f25973 100644 --- a/keras/src/quantizers/quantizers_test.py +++ b/keras/src/quantizers/quantizers_test.py @@ -544,10 +544,9 @@ def quantize_fn(x): _, f_vjp = jax.vjp(quantize_fn, inputs) - # NOTE: When python version >= 3.10, the gradients are at - # `f_vjp.args[0].args[0][0]`. Otherwise, they are at - # `f_vjp.args[0].args[0][1]`. - if sys.version_info >= (3, 10): + if getattr(jax.config, "jax_vjp3", False): + input_gradients = f_vjp.opaque_residuals[0] + elif sys.version_info >= (3, 10): input_gradients = f_vjp.args[0].args[0][0] else: input_gradients = f_vjp.args[0].args[0][1]