Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
1d7c685
added requirements-tensorflow-tpu.txt and tpu configuration in .kokoro
kharshith-k Jun 16, 2025
19b5e6b
updated .kokoro/github/ubuntu/tpu/build.sh with jax and torch backend…
kharshith-k Jun 16, 2025
d203ca3
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 18, 2025
f45e5d0
Changed the tpu CI config files path to .github from .kokoro
kharshith-k Jun 18, 2025
6771cc0
Added new job in .github/workflows/actions.yml to run TPU tests
kharshith-k Jun 18, 2025
87d36e7
fixed runs-on option in acvtions.yml for tpu_build job to run on self…
kharshith-k Jun 18, 2025
9901298
Added another runner in the actions TPU job
kharshith-k Jun 18, 2025
be97210
Update continuous.cfg
kharshith-k Jun 18, 2025
a1cd5c3
Update presubmit.cfg
kharshith-k Jun 18, 2025
c5e3a5c
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 23, 2025
f0ab676
Update actions.yml
kharshith-k Jun 23, 2025
09161d7
Developed Dockerfile for TPU build job in actions.yml
kharshith-k Jun 24, 2025
9a3948f
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 24, 2025
058fdff
Update actions.yml
kharshith-k Jun 24, 2025
d47e39e
Included few more runners in tpu_build job
kharshith-k Jun 26, 2025
a6a59d7
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jun 26, 2025
ba4f6ae
Using linux-x86-ct6e-44-1tpu
kharshith-k Jun 26, 2025
a5a3624
Modified requirement-commmon.txt and updated requirements-tensorflow-…
kharshith-k Jun 30, 2025
b9998af
Added Dtypes_TPU_tests.py and requirements-jax-tpu.txt
kharshith-k Jul 22, 2025
f68be97
Progress bar now handles `steps_per_execution`. (#21422)
hertschuh Jun 26, 2025
1018abf
Fix symbolic call of `logsumexp` with int axis. (#21428)
hertschuh Jun 27, 2025
0da77e4
Only allow deserialization of `KerasSaveable`s by module and name. (#…
hertschuh Jun 29, 2025
cb639c5
commented tensorflow deps
kharshith-k Jul 2, 2025
c0d1743
Added log of dtypes_test_tpu.py and the test script for the same
kharshith-k Jul 2, 2025
306e6e7
modified dtypes_test_tpu.py as per pre-commit standards
kharshith-k Jul 2, 2025
4e584fc
Added TPU initiaization and teardown functionalities in conftest.py, …
kharshith-k Jul 3, 2025
bb09e95
Added dtypes_test_TPU.py and dtypes_new_test.py, modified conftest.py
kharshith-k Jul 9, 2025
8a63d09
Added Dcokerfile and tests list command
kharshith-k Jul 23, 2025
4651454
Updated Dockerfile
kharshith-k Jul 28, 2025
40af241
Restored Dockerfile to previous changes
kharshith-k Jul 28, 2025
64420d5
updated actions.yml file to install and configure docker engine on se…
kharshith-k Jul 28, 2025
da84de5
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jul 28, 2025
d69277d
updated actions.yml file to include container option
kharshith-k Jul 28, 2025
1c307fc
updated actions.yml file to include container option without volume b…
kharshith-k Jul 28, 2025
693886b
updated actions.yml file to change TPU
kharshith-k Jul 28, 2025
e74b851
Updated container path in build-and-test-on-tpu job
kharshith-k Jul 29, 2025
d31b3c4
seperated TPU workflow from actions.yml
kharshith-k Jul 29, 2025
a70d19e
updated trigger condition for TPU tests workflow
kharshith-k Jul 29, 2025
5f5b609
updated container usage configuration for TPU tests workflow
kharshith-k Jul 29, 2025
72e729f
updated env vars for TPU tests workflow
kharshith-k Jul 29, 2025
e129299
updated env vars parsing syntax in TPU tests workflow
kharshith-k Jul 29, 2025
3fe5b57
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
10df307
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
dd21e09
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
328628f
updated env vars syntax in TPU tests workflow
kharshith-k Jul 29, 2025
01f0c17
updated image name in TPU tests workflow
kharshith-k Jul 29, 2025
3e41c37
updated image name with generic ubuntu image
kharshith-k Jul 29, 2025
5e55c2c
updated tpu-tests to use ghcr
kharshith-k Jul 29, 2025
ea9ff88
updated tpu-tests to store built image as local tar
kharshith-k Jul 29, 2025
6d92aa9
updated image name from ubuntu:22.04 to docker:24.0-cli in tpu tests …
kharshith-k Jul 29, 2025
3c75bf8
updated image name from docker:24.0-cli to ubuntu:22.04 in tpu tests…
kharshith-k Jul 29, 2025
1589a75
added volume mount from host in load-and-test-job
kharshith-k Jul 29, 2025
36bd682
Merge branch 'keras-team:master' into tf-tpu
kharshith-k Jul 29, 2025
04112cf
Reverted tpu-tests.yml to version using ghcr.io for image storage
kharshith-k Jul 29, 2025
87d7ad8
Removed custom dtypes_test files for TPU testing and restored origina…
kharshith-k Aug 12, 2025
6cb097c
Updated tpu-tests.yml to pull image from GCP artifact registry
kharshith-k Aug 12, 2025
4829f1b
Resolved conflicts in actions.yml
kharshith-k Aug 12, 2025
a2eb306
Added a workflow to check service accounts associated with self hoste…
kharshith-k Aug 12, 2025
23579c4
Made find_sa.yml specific to linux-x86-ct6e-44-1tpu
kharshith-k Aug 12, 2025
dac6433
Added container tag to find_sa.yml
kharshith-k Aug 12, 2025
05461c1
Checking SA for linux-x86-ct5lp-112-4tpu
kharshith-k Aug 12, 2025
078dcee
Checking SA for linux-x86-ct6e-44-1tpu-nxgm7-runner-vb87c
kharshith-k Aug 12, 2025
016c68d
Using SA for auth in tpu-tests
kharshith-k Aug 12, 2025
02657f0
Updated SA with container tag for auth in tpu-tests
kharshith-k Aug 12, 2025
7167952
Added docker socket mount test
kharshith-k Aug 12, 2025
543cf65
Updated tpu-tests to just pull and test the image from artifact regis…
kharshith-k Aug 14, 2025
a2401c0
Added pytest command to the workflow
kharshith-k Aug 14, 2025
a98c748
added grain installation command
kharshith-k Aug 14, 2025
71c5b8b
Pruned unwanted files
kharshith-k Aug 19, 2025
5522a2b
included grain in requirements.txt
kharshith-k Aug 19, 2025
7173c6d
Updated tpu-tests.yml to use python image and explicitly install spec…
kharshith-k Aug 22, 2025
a7dc789
Renamed tpu-tests to tpu-tests-jax and logging TPU device kind
kharshith-k Aug 22, 2025
e509e6d
Added a step to check gcloud installation
kharshith-k Aug 22, 2025
a7ec63b
Running pytest on generic tpu workflow
kharshith-k Aug 25, 2025
dfa7bc2
Made changes as per suggestions in PR
kharshith-k Nov 3, 2025
ad9e073
Merge branch 'master' into tf-tpu
kharshith-k Nov 3, 2025
03743bd
Fixed error in action file
kharshith-k Nov 3, 2025
babb216
Added a job in tpu workflow to persist failed tests list
kharshith-k Nov 3, 2025
c766b69
using requirements-jax-tpu.txt
kharshith-k Nov 3, 2025
f2f1c9c
Reverted the tpu-tests-jax.yml
kharshith-k Nov 3, 2025
201cefe
Removed a command line option in tpu workflow file
kharshith-k Nov 3, 2025
cf8f29a
Removed uninstall step from tpu workflow job
kharshith-k Nov 3, 2025
2840f3f
reverted tensorflow version in requirements file
kharshith-k Nov 3, 2025
9edff16
Updated the tpu workflow to skip failing tests
kharshith-k Nov 4, 2025
82ddbad
Clean up TPU tests workflow by removing comment
kharshith-k Nov 4, 2025
597f27d
Changed the failed tests file path and updated the same in conftest.py
kharshith-k Nov 4, 2025
356b62b
Updated the failed test file path
kharshith-k Nov 4, 2025
4e0d3af
Updated the workflow and job names for TPU
kharshith-k Nov 5, 2025
0d0b009
Added TPU job in actions.yml
kharshith-k Nov 6, 2025
65252a4
Added TPU config print line
kharshith-k Nov 6, 2025
1769e32
Added condition to TPU test job so that it gets triggered only after …
kharshith-k Nov 11, 2025
c44ce18
Added TPU specific tests in seperate workflow with PR approval condition
kharshith-k Nov 12, 2025
4b84da1
Removed pull_request condition for workflow execution and renamed the…
kharshith-k Nov 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,4 @@ jobs:
pip uninstall -y keras keras-nightly
pip install -e "." --progress-bar off --upgrade
- name: Run pre-commit
run: pre-commit run --all-files --hook-stage manual
run: pre-commit run --all-files --hook-stage manual
56 changes: 56 additions & 0 deletions .github/workflows/tpu_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
name: Keras Tests

# TODO: Consider enabling all tests (pytest, applications, etc.) with NNX in the future
# Currently only basic flow tests run with NNX enabled

on:
push:
branches: [ master ]
pull_request_review:
types: [submitted]
release:
types: [created]

permissions:
contents: read

jobs:

test-in-container:
name: Run tests on TPU
runs-on: linux-x86-ct6e-44-1tpu
# Only run on approved PRs, pushes to master, or releases
if: |
github.event_name == 'push' ||
github.event_name == 'release' ||
(github.event_name == 'pull_request_review' && github.event.review.state == 'approved')

strategy:
fail-fast: false
matrix:
backend: [jax]

container:
image: python:3.10-slim
options: --privileged --network host

steps:
- name: Checkout Repository
uses: actions/checkout@v4

- name: Install Dependencies
run: |
pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt \

- name: Set Keras Backend
run: echo "KERAS_BACKEND=jax" >> $GITHUB_ENV

- name: Run Verification and Tests
run: |
echo "Successfully running inside the public python container!"
echo "Verifying JAX installation..."
python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices : {jax.devices()}')"

pytest keras --ignore keras/src/applications \
--cov=keras \
--cov-config=pyproject.toml
23 changes: 23 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ def pytest_collection_modifyitems(config, items):
line.strip() for line in openvino_skipped_tests if line.strip()
]

tpu_skipped_tests = []
if backend() == "jax":
try:
with open(
"keras/src/backend/jax/excluded_tpu_tests.txt", "r"
) as file:
tpu_skipped_tests = file.readlines()
# it is necessary to check if stripped line is not empty
# and exclude such lines
tpu_skipped_tests = [
line.strip() for line in tpu_skipped_tests if line.strip()
]
except FileNotFoundError:
pass # File doesn't exist, no tests to skip

requires_trainable_backend = pytest.mark.skipif(
backend() in ["numpy", "openvino"],
reason="Trainer not implemented for NumPy and OpenVINO backend.",
Expand All @@ -49,6 +64,14 @@ def pytest_collection_modifyitems(config, items):
"Not supported operation by openvino backend",
)
)
# also, skip concrete tests for TPU when using JAX backend
for skipped_test in tpu_skipped_tests:
if skipped_test in item.nodeid:
item.add_marker(
pytest.mark.skip(
reason="Known TPU test failure",
)
)


def skip_if_backend(given_backend, reason):
Expand Down
234 changes: 234 additions & 0 deletions keras/src/backend/jax/excluded_tpu_tests.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
AdditiveAttentionTest::test_attention_correctness
AttentionTest::test_attention_calculate_scores_with_scale
AttentionTest::test_attention_correctness
CircleTest::test_correctness
CircleTest::test_correctness_weighted
CircleTest::test_mean_with_sample_weight_reduction
CircleTest::test_no_reduction
CircleTest::test_sum_reduction
ConvBasicTest::test_enable_lora_with_alpha
ConvCorrectnessTest::test_conv1d0
ConvCorrectnessTest::test_conv1d1
ConvCorrectnessTest::test_conv1d2
ConvCorrectnessTest::test_conv1d3
ConvCorrectnessTest::test_conv1d4
ConvCorrectnessTest::test_conv2d0
ConvCorrectnessTest::test_conv2d1
ConvCorrectnessTest::test_conv2d2
ConvCorrectnessTest::test_conv2d3
ConvCorrectnessTest::test_conv2d4
ConvCorrectnessTest::test_conv2d5
ConvCorrectnessTest::test_conv3d0
ConvCorrectnessTest::test_conv3d1
ConvCorrectnessTest::test_conv3d2
ConvCorrectnessTest::test_conv3d3
ConvCorrectnessTest::test_conv3d4
ConvLSTM1DTest::test_correctness
ConvLSTM1DTest::test_correctness
ConvLSTM2DTest::test_correctness
ConvLSTMCellTest::test_correctness
ConvLSTMTest::test_correctness
ConvTransposeCorrectnessTest::test_conv1d_transpose0
ConvTransposeCorrectnessTest::test_conv1d_transpose1
ConvTransposeCorrectnessTest::test_conv1d_transpose2
ConvTransposeCorrectnessTest::test_conv2d_transpose0
ConvTransposeCorrectnessTest::test_conv2d_transpose1
ConvTransposeCorrectnessTest::test_conv2d_transpose2
ConvTransposeCorrectnessTest::test_conv2d_transpose3
ConvTransposeCorrectnessTest::test_conv3d_transpose0
ConvTransposeCorrectnessTest::test_conv3d_transpose1
ConvTransposeCorrectnessTest::test_conv3d_transpose2
CTCTest::test_correctness
DenseTest::test_dense_sparse
DepthwiseConvCorrectnessTest::test_depthwise_conv1d0
DepthwiseConvCorrectnessTest::test_depthwise_conv1d1
DepthwiseConvCorrectnessTest::test_depthwise_conv1d2
DepthwiseConvCorrectnessTest::test_depthwise_conv2d0
DepthwiseConvCorrectnessTest::test_depthwise_conv2d1
DepthwiseConvCorrectnessTest::test_depthwise_conv2d2
EinsumDenseTest::test_enable_lora_with_alpha
EmbeddingTest::test_enable_lora_with_alpha
ExportArchiveTest::test_jax_endpoint_registration_tf_function
ExportArchiveTest::test_jax_multi_unknown_endpoint_registration
ExportArchiveTest::test_layer_export
ExportArchiveTest::test_low_level_model_export_functional
ExportArchiveTest::test_low_level_model_export_sequential
ExportArchiveTest::test_low_level_model_export_subclass
ExportArchiveTest::test_low_level_model_export_with_alias
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_functional
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_sequential
ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_subclass
ExportArchiveTest::test_low_level_model_export_with_jax2tf_kwargs
ExportArchiveTest::test_low_level_model_export_with_jax2tf_polymorphic_shapes
ExportArchiveTest::test_model_combined_with_tf_preprocessing
ExportArchiveTest::test_model_export_method_functional
ExportArchiveTest::test_model_export_method_sequential
ExportArchiveTest::test_model_export_method_subclass
ExportArchiveTest::test_multi_input_output_functional_model
ExportArchiveTest::test_non_standard_layer_signature
ExportArchiveTest::test_non_standard_layer_signature_with_kwargs
ExportArchiveTest::test_track_multiple_layers
ExportONNXTest::test_export_with_input_names
ExportONNXTest::test_export_with_opset_version_18
ExportONNXTest::test_export_with_opset_version_none
ExportONNXTest::test_standard_model_export_functional
ExportONNXTest::test_standard_model_export_lstm
ExportONNXTest::test_standard_model_export_sequential
ExportONNXTest::test_standard_model_export_subclass
ExportOpenVINOTest::test_standard_model_export_functional
ExportOpenVINOTest::test_standard_model_export_sequential
ExportOpenVINOTest::test_standard_model_export_subclass
ExportSavedModelTest::test_input_signature_functional_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
ExportSavedModelTest::test_input_signature_functional_backend_tensor
ExportSavedModelTest::test_input_signature_functional_inputspec(dtype=float32, shape=(none, 10), ndim=2)
ExportSavedModelTest::test_input_signature_functional_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
ExportSavedModelTest::test_input_signature_sequential_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
ExportSavedModelTest::test_input_signature_sequential_backend_tensor
ExportSavedModelTest::test_input_signature_sequential_inputspec(dtype=float32, shape=(none, 10), ndim=2)
ExportSavedModelTest::test_input_signature_sequential_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
ExportSavedModelTest::test_input_signature_subclass_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>
ExportSavedModelTest::test_input_signature_subclass_backend_tensor
ExportSavedModelTest::test_input_signature_subclass_inputspec(dtype=float32, shape=(none, 10), ndim=2)
ExportSavedModelTest::test_input_signature_subclass_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')
ExportSavedModelTest::test_jax_specific_kwargs_functional_false_{'enable_xla': true, 'native_serialization': true}
ExportSavedModelTest::test_jax_specific_kwargs_functional_false_none
ExportSavedModelTest::test_jax_specific_kwargs_functional_true_{'enable_xla': true, 'native_serialization': true}
ExportSavedModelTest::test_jax_specific_kwargs_functional_true_none
ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_{'enable_xla': true, 'native_serialization': true}
ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_none
ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_{'enable_xla': true, 'native_serialization': true}
ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_none
ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_{'enable_xla': true, 'native_serialization': true}
ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_none
ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_{'enable_xla': true, 'native_serialization': true}
ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_none
ExportSavedModelTest::test_model_with_input_structure_array
ExportSavedModelTest::test_model_with_input_structure_dict
ExportSavedModelTest::test_model_with_input_structure_tuple
ExportSavedModelTest::test_model_with_multiple_inputs
ExportSavedModelTest::test_model_with_non_trainable_state_export_functional
ExportSavedModelTest::test_model_with_non_trainable_state_export_sequential
ExportSavedModelTest::test_model_with_non_trainable_state_export_subclass
ExportSavedModelTest::test_model_with_rng_export_functional
ExportSavedModelTest::test_model_with_rng_export_sequential
ExportSavedModelTest::test_model_with_rng_export_subclass
ExportSavedModelTest::test_model_with_tf_data_layer_functional
ExportSavedModelTest::test_model_with_tf_data_layer_sequential
ExportSavedModelTest::test_model_with_tf_data_layer_subclass
ExportSavedModelTest::test_standard_model_export_functional
ExportSavedModelTest::test_standard_model_export_sequential
ExportSavedModelTest::test_standard_model_export_subclass
GRUTest::test_correctness0
GRUTest::test_correctness1
GRUTest::test_legacy_implementation_argument
GRUTest::test_masking
GRUTest::test_pass_initial_state
GRUTest::test_pass_return_state
GRUTest::test_statefulness
ImageOpsCorrectnessTest::test_affine_transform_bilinear_constant
ImageOpsCorrectnessTest::test_affine_transform_bilinear_mirror
ImageOpsCorrectnessTest::test_affine_transform_bilinear_nearest
ImageOpsCorrectnessTest::test_affine_transform_bilinear_reflect
ImageOpsCorrectnessTest::test_affine_transform_bilinear_wrap
LinalgOpsCorrectnessTest::test_cholesky_inverse_lower
LinalgOpsCorrectnessTest::test_cholesky_inverse_upper
LinalgOpsCorrectnessTest::test_eig
LinalgOpsCorrectnessTest::test_svd
LSTMTest::test_correctness0
LSTMTest::test_correctness1
LSTMTest::test_masking
LSTMTest::test_pass_initial_state
LSTMTest::test_statefulness
MathOpsCorrectnessTest::test_extract_sequences
MergingLayersTest::test_correctness_dynamic_dot_3d
MergingLayersTest::test_correctness_static_dot_3d
MuonTest::test_Newton_Schulz
NNOpsCorrectnessTest::test_conv_2d0
NNOpsCorrectnessTest::test_conv_2d1
NNOpsCorrectnessTest::test_conv_2d2
NNOpsCorrectnessTest::test_conv_2d3
NNOpsCorrectnessTest::test_conv_2d4
NNOpsCorrectnessTest::test_conv_2d5
NNOpsCorrectnessTest::test_conv_3d0
NNOpsCorrectnessTest::test_conv_3d1
NNOpsCorrectnessTest::test_conv_3d10
NNOpsCorrectnessTest::test_conv_3d11
NNOpsCorrectnessTest::test_conv_3d2
NNOpsCorrectnessTest::test_conv_3d3
NNOpsCorrectnessTest::test_conv_3d4
NNOpsCorrectnessTest::test_conv_3d5
NNOpsCorrectnessTest::test_conv_3d6
NNOpsCorrectnessTest::test_conv_3d7
NNOpsCorrectnessTest::test_conv_3d8
NNOpsCorrectnessTest::test_conv_3d9
NNOpsCorrectnessTest::test_ctc_loss
NNOpsCorrectnessTest::test_depthwise_conv_2d0
NNOpsCorrectnessTest::test_depthwise_conv_2d1
NNOpsCorrectnessTest::test_depthwise_conv_2d10
NNOpsCorrectnessTest::test_depthwise_conv_2d11
NNOpsCorrectnessTest::test_depthwise_conv_2d2
NNOpsCorrectnessTest::test_depthwise_conv_2d3
NNOpsCorrectnessTest::test_depthwise_conv_2d4
NNOpsCorrectnessTest::test_depthwise_conv_2d5
NNOpsCorrectnessTest::test_depthwise_conv_2d6
NNOpsCorrectnessTest::test_depthwise_conv_2d7
NNOpsCorrectnessTest::test_depthwise_conv_2d8
NNOpsCorrectnessTest::test_depthwise_conv_2d9
NNOpsCorrectnessTest::test_separable_conv_2d0
NNOpsCorrectnessTest::test_separable_conv_2d1
NNOpsCorrectnessTest::test_separable_conv_2d2
NNOpsCorrectnessTest::test_separable_conv_2d3
NNOpsCorrectnessTest::test_separable_conv_2d4
NNOpsCorrectnessTest::test_separable_conv_2d5
NNOpsCorrectnessTest::test_separable_conv_2d6
NNOpsCorrectnessTest::test_separable_conv_2d7
NumpyOneInputOpsDynamicShapeTest::test_argmax_negative_zero
NumpyOneInputOpsDynamicShapeTest::test_argmin_negative_zero
NumpyTwoInputOpsCorrectnessTest::test_logspace
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float32_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float64_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float16_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float32_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float64_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float16_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float32_false_false
NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float64_false_false
RandomGaussianBlurTest::test_random_erasing_basic
RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_large_scale
RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_small_scale
RandomZoomTest::test_random_zoom_out_correctness
RegularizersTest::test_orthogonal_regularizer
RNNTest::test_go_backwards
SeparableConvCorrectnessTest::test_separable_conv1d0
SeparableConvCorrectnessTest::test_separable_conv1d1
SeparableConvCorrectnessTest::test_separable_conv1d2
SeparableConvCorrectnessTest::test_separable_conv2d0
SeparableConvCorrectnessTest::test_separable_conv2d1
SeparableConvCorrectnessTest::test_separable_conv2d2
SimpleRNNTest::test_correctness
SimpleRNNTest::test_correctness
SimpleRNNTest::test_masking
SimpleRNNTest::test_masking
SimpleRNNTest::test_pass_initial_state
SimpleRNNTest::test_pass_initial_state
SimpleRNNTest::test_return_state
SimpleRNNTest::test_statefulness
SimpleRNNTest::test_statefulness
StackedRNNTest::test_correctness_single_state_stack
StackedRNNTest::test_correctness_two_states_stack
StackedRNNTest::test_statefullness_single_state_stack
StackedRNNTest::test_statefullness_two_states_stack
TestFitLRSchedulesFlow::test_fit_lr_correctness
TestJaxLayer::test_flax_layer_training_independent_bound_method
TestJaxLayer::test_flax_layer_training_rng_state_no_method
TestJaxLayer::test_flax_layer_training_rng_unbound_method
TestJaxLayer::test_flax_layer_training_rng_unbound_method_dtype_policy
TestJaxLayer::test_jax_layer_training_independent
TestJaxLayer::test_jax_layer_training_state
TestJaxLayer::test_jax_layer_training_state_dtype_policy
TestSpectrogram::test_spectrogram_error
TestTrainer::test_loss_weights
TestTrainer::test_nested_inputs
TestTrainer::test_on_batch_methods_eager
TestTrainer::test_on_batch_methods_graph_fn
TestTrainer::test_on_batch_methods_jit
1 change: 0 additions & 1 deletion requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,4 @@ onnxruntime
# onnxscript==0.3.2 breaks LSTM model export.
onnxscript!=0.3.2
openvino
# for grain_dataset_adapter_test.py
grain
14 changes: 14 additions & 0 deletions requirements-jax-tpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Tensorflow cpu-only version (needed for testing).
tensorflow-cpu~=2.18.1
tf2onnx

# Torch cpu-only version (needed for testing).
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.6.0

# Jax with cuda support.
--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
jax[tpu]
flax

-r requirements-common.txt
13 changes: 13 additions & 0 deletions requirements-tensorflow-tpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
--find-links https://storage.googleapis.com/libtpu-tf-releases/index.html
tensorflow-tpu==2.19.1

tf2onnx

# Torch cpu-only version (needed for testing).
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.6.0

# Jax cpu-only version (needed for testing).
jax

-r requirements-common.txt