Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions keras/src/quantizers/quantizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,10 +544,9 @@ def quantize_fn(x):

_, f_vjp = jax.vjp(quantize_fn, inputs)

# NOTE: When python version >= 3.10, the gradients are at
# `f_vjp.args[0].args[0][0]`. Otherwise, they are at
# `f_vjp.args[0].args[0][1]`.
if sys.version_info >= (3, 10):
if jax.config.jax_vjp3:
input_gradients = f_vjp.opaque_residuals[0]
elif sys.version_info >= (3, 10):
input_gradients = f_vjp.args[0].args[0][0]
else:
input_gradients = f_vjp.args[0].args[0][1]
Expand Down
Loading