Skip to content

Commit ccbc9d4

Browse files
authored
Update Torch and Tensorflow versions in cuda requirements files. (#21732)
Replacement for #21704 Also: - Disabled an ONNX export test for Torch that was already disabled on GPU with both JAX and Tensorflow. - Moved install for `tf_keras` from `requirements.txt` to `action.yml` using the `--no-deps` option because `tf_keras` depends on `tensorflow`, which installs the non-CPU version of TensorFlow and causes issues with CPU tests.
1 parent 6f7e893 commit ccbc9d4

File tree

6 files changed

+16
-12
lines changed

6 files changed

+16
-12
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ jobs:
5959
if [ "${{ matrix.nnx_enabled }}" == "true" ]; then
6060
pip install --upgrade flax>=0.11.1
6161
fi
62+
pip install --no-deps tf_keras==2.18.0
6263
pip uninstall -y keras keras-nightly
6364
pip install -e "." --progress-bar off --upgrade
6465
- name: Test applications with pytest

keras/src/export/onnx_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,11 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
7777
"backends."
7878
),
7979
)
80-
@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI")
8180
@pytest.mark.skipif(
82-
testing.tensorflow_uses_gpu(), reason="Leads to core dumps on CI"
81+
testing.jax_uses_gpu()
82+
or testing.tensorflow_uses_gpu()
83+
or testing.torch_uses_gpu(),
84+
reason="Fails on GPU",
8385
)
8486
class ExportONNXTest(testing.TestCase):
8587
@parameterized.named_parameters(

requirements-jax-cuda.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Tensorflow cpu-only version (needed for testing).
2-
tensorflow-cpu~=2.18.1
2+
tensorflow-cpu~=2.20.0
33
tf2onnx
44

55
# Torch cpu-only version (needed for testing).
66
--extra-index-url https://download.pytorch.org/whl/cpu
7-
torch==2.6.0
7+
torch==2.8.0
88

99
# Jax with cuda support.
1010
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

requirements-tensorflow-cuda.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Tensorflow with cuda support.
2-
tensorflow[and-cuda]~=2.18.1
2+
tensorflow[and-cuda]~=2.20.0
33
tf2onnx
44

55
# Torch cpu-only version (needed for testing).
66
--extra-index-url https://download.pytorch.org/whl/cpu
7-
torch==2.6.0
7+
torch==2.8.0
88

99
# Jax cpu-only version (needed for testing).
1010
jax[cpu]

requirements-torch-cuda.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# Tensorflow cpu-only version (needed for testing).
2-
tensorflow-cpu~=2.18.1
2+
tensorflow-cpu~=2.20.0
33
tf2onnx
44

55
# Torch with cuda support.
66
# - torch is pinned to a version that is compatible with torch-xla.
77
--extra-index-url https://download.pytorch.org/whl/cu121
8-
torch==2.6.0
9-
torch-xla==2.6.0;sys_platform != 'darwin'
8+
torch==2.8.0
9+
torch-xla==2.8.1;sys_platform != 'darwin'
1010

1111
# Jax cpu-only version (needed for testing).
1212
jax[cpu]

requirements.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
# Tensorflow.
2+
# Note: when the version of Tensorflow is changed, the version tf_keras must be
3+
# changed in .github/workflows/actions.yml (pip install --no-deps tf_keras).
24
tensorflow-cpu~=2.18.1;sys_platform != 'darwin'
35
tensorflow~=2.18.1;sys_platform == 'darwin'
4-
tf_keras
56
tf2onnx
67

78
# Torch.
89
--extra-index-url https://download.pytorch.org/whl/cpu
9-
torch==2.6.0;sys_platform != 'darwin'
10-
torch==2.6.0;sys_platform == 'darwin'
10+
torch==2.6.0
1111
torch-xla==2.6.0;sys_platform != 'darwin'
1212

1313
# Jax.
1414
# Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test.
1515
# Note that we test against the latest JAX on GPU.
1616
jax[cpu]==0.5.0
1717
flax
18+
1819
# Common deps.
1920
-r requirements-common.txt

0 commit comments

Comments
 (0)