Skip to content

Commit 6f37855

Browse files
mattjjyashk2810
andcommitted
update test after jax.config.jax_vjp3 is enabled
Co-authored-by: Yash Katariya <yashkatariya@google.com>
1 parent 18f79d6 commit 6f37855

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

keras/src/quantizers/quantizers_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -544,10 +544,9 @@ def quantize_fn(x):
544544

545545
_, f_vjp = jax.vjp(quantize_fn, inputs)
546546

547-
# NOTE: When python version >= 3.10, the gradients are at
548-
# `f_vjp.args[0].args[0][0]`. Otherwise, they are at
549-
# `f_vjp.args[0].args[0][1]`.
550-
if sys.version_info >= (3, 10):
547+
if jax.config.jax_vjp3.value:
548+
input_gradients = f_vjp.opaque_residuals[0]
549+
elif sys.version_info >= (3, 10):
551550
input_gradients = f_vjp.args[0].args[0][0]
552551
else:
553552
input_gradients = f_vjp.args[0].args[0][1]

0 commit comments

Comments
 (0)