We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 18f79d6 commit 6f37855Copy full SHA for 6f37855
keras/src/quantizers/quantizers_test.py
@@ -544,10 +544,9 @@ def quantize_fn(x):
544
545
_, f_vjp = jax.vjp(quantize_fn, inputs)
546
547
- # NOTE: When python version >= 3.10, the gradients are at
548
- # `f_vjp.args[0].args[0][0]`. Otherwise, they are at
549
- # `f_vjp.args[0].args[0][1]`.
550
- if sys.version_info >= (3, 10):
+ if jax.config.jax_vjp3.value:
+ input_gradients = f_vjp.opaque_residuals[0]
+ elif sys.version_info >= (3, 10):
551
input_gradients = f_vjp.args[0].args[0][0]
552
else:
553
input_gradients = f_vjp.args[0].args[0][1]
0 commit comments