diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index f4a27394247c..d22d92209071 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -148,4 +148,4 @@ jobs: pip uninstall -y keras keras-nightly pip install -e "." --progress-bar off --upgrade - name: Run pre-commit - run: pre-commit run --all-files --hook-stage manual + run: pre-commit run --all-files --hook-stage manual \ No newline at end of file diff --git a/.github/workflows/tpu_tests.yml b/.github/workflows/tpu_tests.yml new file mode 100644 index 000000000000..523b89971454 --- /dev/null +++ b/.github/workflows/tpu_tests.yml @@ -0,0 +1,56 @@ +name: Keras Tests + +# TODO: Consider enabling all tests (pytest, applications, etc.) with NNX in the future +# Currently only basic flow tests run with NNX enabled + +on: + push: + branches: [ master ] + pull_request_review: + types: [submitted] + release: + types: [created] + +permissions: + contents: read + +jobs: + + test-in-container: + name: Run tests on TPU + runs-on: linux-x86-ct6e-44-1tpu + # Only run on approved PRs, pushes to master, or releases + if: | + github.event_name == 'push' || + github.event_name == 'release' || + (github.event_name == 'pull_request_review' && github.event.review.state == 'approved') + + strategy: + fail-fast: false + matrix: + backend: [jax] + + container: + image: python:3.10-slim + options: --privileged --network host + + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + + - name: Install Dependencies + run: | + pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt \ + + - name: Set Keras Backend + run: echo "KERAS_BACKEND=jax" >> $GITHUB_ENV + + - name: Run Verification and Tests + run: | + echo "Successfully running inside the public python container!" + echo "Verifying JAX installation..." 
+ python3 -c "import jax; print(f'JAX backend: {jax.default_backend()}'); print(f'JAX devices : {jax.devices()}')" + + pytest keras --ignore keras/src/applications \ + --cov=keras \ + --cov-config=pyproject.toml diff --git a/conftest.py b/conftest.py index 9853ff86baf1..55ba6832ba6f 100644 --- a/conftest.py +++ b/conftest.py @@ -31,6 +31,21 @@ def pytest_collection_modifyitems(config, items): line.strip() for line in openvino_skipped_tests if line.strip() ] + tpu_skipped_tests = [] + if backend() == "jax": + try: + with open( + "keras/src/backend/jax/excluded_tpu_tests.txt", "r" + ) as file: + tpu_skipped_tests = file.readlines() + # it is necessary to check if stripped line is not empty + # and exclude such lines + tpu_skipped_tests = [ + line.strip() for line in tpu_skipped_tests if line.strip() + ] + except FileNotFoundError: + pass # File doesn't exist, no tests to skip + requires_trainable_backend = pytest.mark.skipif( backend() in ["numpy", "openvino"], reason="Trainer not implemented for NumPy and OpenVINO backend.", @@ -49,6 +64,14 @@ def pytest_collection_modifyitems(config, items): "Not supported operation by openvino backend", ) ) + # also, skip concrete tests for TPU when using JAX backend + for skipped_test in tpu_skipped_tests: + if skipped_test in item.nodeid: + item.add_marker( + pytest.mark.skip( + reason="Known TPU test failure", + ) + ) def skip_if_backend(given_backend, reason): diff --git a/keras/src/backend/jax/excluded_tpu_tests.txt b/keras/src/backend/jax/excluded_tpu_tests.txt new file mode 100644 index 000000000000..13a7b799aca1 --- /dev/null +++ b/keras/src/backend/jax/excluded_tpu_tests.txt @@ -0,0 +1,234 @@ +AdditiveAttentionTest::test_attention_correctness +AttentionTest::test_attention_calculate_scores_with_scale +AttentionTest::test_attention_correctness +CircleTest::test_correctness +CircleTest::test_correctness_weighted +CircleTest::test_mean_with_sample_weight_reduction +CircleTest::test_no_reduction 
+CircleTest::test_sum_reduction +ConvBasicTest::test_enable_lora_with_alpha +ConvCorrectnessTest::test_conv1d0 +ConvCorrectnessTest::test_conv1d1 +ConvCorrectnessTest::test_conv1d2 +ConvCorrectnessTest::test_conv1d3 +ConvCorrectnessTest::test_conv1d4 +ConvCorrectnessTest::test_conv2d0 +ConvCorrectnessTest::test_conv2d1 +ConvCorrectnessTest::test_conv2d2 +ConvCorrectnessTest::test_conv2d3 +ConvCorrectnessTest::test_conv2d4 +ConvCorrectnessTest::test_conv2d5 +ConvCorrectnessTest::test_conv3d0 +ConvCorrectnessTest::test_conv3d1 +ConvCorrectnessTest::test_conv3d2 +ConvCorrectnessTest::test_conv3d3 +ConvCorrectnessTest::test_conv3d4 +ConvLSTM1DTest::test_correctness +ConvLSTM1DTest::test_correctness +ConvLSTM2DTest::test_correctness +ConvLSTMCellTest::test_correctness +ConvLSTMTest::test_correctness +ConvTransposeCorrectnessTest::test_conv1d_transpose0 +ConvTransposeCorrectnessTest::test_conv1d_transpose1 +ConvTransposeCorrectnessTest::test_conv1d_transpose2 +ConvTransposeCorrectnessTest::test_conv2d_transpose0 +ConvTransposeCorrectnessTest::test_conv2d_transpose1 +ConvTransposeCorrectnessTest::test_conv2d_transpose2 +ConvTransposeCorrectnessTest::test_conv2d_transpose3 +ConvTransposeCorrectnessTest::test_conv3d_transpose0 +ConvTransposeCorrectnessTest::test_conv3d_transpose1 +ConvTransposeCorrectnessTest::test_conv3d_transpose2 +CTCTest::test_correctness +DenseTest::test_dense_sparse +DepthwiseConvCorrectnessTest::test_depthwise_conv1d0 +DepthwiseConvCorrectnessTest::test_depthwise_conv1d1 +DepthwiseConvCorrectnessTest::test_depthwise_conv1d2 +DepthwiseConvCorrectnessTest::test_depthwise_conv2d0 +DepthwiseConvCorrectnessTest::test_depthwise_conv2d1 +DepthwiseConvCorrectnessTest::test_depthwise_conv2d2 +EinsumDenseTest::test_enable_lora_with_alpha +EmbeddingTest::test_enable_lora_with_alpha +ExportArchiveTest::test_jax_endpoint_registration_tf_function +ExportArchiveTest::test_jax_multi_unknown_endpoint_registration +ExportArchiveTest::test_layer_export 
+ExportArchiveTest::test_low_level_model_export_functional +ExportArchiveTest::test_low_level_model_export_sequential +ExportArchiveTest::test_low_level_model_export_subclass +ExportArchiveTest::test_low_level_model_export_with_alias +ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_functional +ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_sequential +ExportArchiveTest::test_low_level_model_export_with_dynamic_dims_subclass +ExportArchiveTest::test_low_level_model_export_with_jax2tf_kwargs +ExportArchiveTest::test_low_level_model_export_with_jax2tf_polymorphic_shapes +ExportArchiveTest::test_model_combined_with_tf_preprocessing +ExportArchiveTest::test_model_export_method_functional +ExportArchiveTest::test_model_export_method_sequential +ExportArchiveTest::test_model_export_method_subclass +ExportArchiveTest::test_multi_input_output_functional_model +ExportArchiveTest::test_non_standard_layer_signature +ExportArchiveTest::test_non_standard_layer_signature_with_kwargs +ExportArchiveTest::test_track_multiple_layers +ExportONNXTest::test_export_with_input_names +ExportONNXTest::test_export_with_opset_version_18 +ExportONNXTest::test_export_with_opset_version_none +ExportONNXTest::test_standard_model_export_functional +ExportONNXTest::test_standard_model_export_lstm +ExportONNXTest::test_standard_model_export_sequential +ExportONNXTest::test_standard_model_export_subclass +ExportOpenVINOTest::test_standard_model_export_functional +ExportOpenVINOTest::test_standard_model_export_sequential +ExportOpenVINOTest::test_standard_model_export_subclass +ExportSavedModelTest::test_input_signature_functional_ +ExportSavedModelTest::test_input_signature_functional_backend_tensor +ExportSavedModelTest::test_input_signature_functional_inputspec(dtype=float32, shape=(none, 10), ndim=2) +ExportSavedModelTest::test_input_signature_functional_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs') 
+ExportSavedModelTest::test_input_signature_sequential_ +ExportSavedModelTest::test_input_signature_sequential_backend_tensor +ExportSavedModelTest::test_input_signature_sequential_inputspec(dtype=float32, shape=(none, 10), ndim=2) +ExportSavedModelTest::test_input_signature_sequential_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs') +ExportSavedModelTest::test_input_signature_subclass_ +ExportSavedModelTest::test_input_signature_subclass_backend_tensor +ExportSavedModelTest::test_input_signature_subclass_inputspec(dtype=float32, shape=(none, 10), ndim=2) +ExportSavedModelTest::test_input_signature_subclass_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs') +ExportSavedModelTest::test_jax_specific_kwargs_functional_false_{'enable_xla': true, 'native_serialization': true} +ExportSavedModelTest::test_jax_specific_kwargs_functional_false_none +ExportSavedModelTest::test_jax_specific_kwargs_functional_true_{'enable_xla': true, 'native_serialization': true} +ExportSavedModelTest::test_jax_specific_kwargs_functional_true_none +ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_{'enable_xla': true, 'native_serialization': true} +ExportSavedModelTest::test_jax_specific_kwargs_sequential_false_none +ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_{'enable_xla': true, 'native_serialization': true} +ExportSavedModelTest::test_jax_specific_kwargs_sequential_true_none +ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_{'enable_xla': true, 'native_serialization': true} +ExportSavedModelTest::test_jax_specific_kwargs_subclass_false_none +ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_{'enable_xla': true, 'native_serialization': true} +ExportSavedModelTest::test_jax_specific_kwargs_subclass_true_none +ExportSavedModelTest::test_model_with_input_structure_array +ExportSavedModelTest::test_model_with_input_structure_dict +ExportSavedModelTest::test_model_with_input_structure_tuple 
+ExportSavedModelTest::test_model_with_multiple_inputs +ExportSavedModelTest::test_model_with_non_trainable_state_export_functional +ExportSavedModelTest::test_model_with_non_trainable_state_export_sequential +ExportSavedModelTest::test_model_with_non_trainable_state_export_subclass +ExportSavedModelTest::test_model_with_rng_export_functional +ExportSavedModelTest::test_model_with_rng_export_sequential +ExportSavedModelTest::test_model_with_rng_export_subclass +ExportSavedModelTest::test_model_with_tf_data_layer_functional +ExportSavedModelTest::test_model_with_tf_data_layer_sequential +ExportSavedModelTest::test_model_with_tf_data_layer_subclass +ExportSavedModelTest::test_standard_model_export_functional +ExportSavedModelTest::test_standard_model_export_sequential +ExportSavedModelTest::test_standard_model_export_subclass +GRUTest::test_correctness0 +GRUTest::test_correctness1 +GRUTest::test_legacy_implementation_argument +GRUTest::test_masking +GRUTest::test_pass_initial_state +GRUTest::test_pass_return_state +GRUTest::test_statefulness +ImageOpsCorrectnessTest::test_affine_transform_bilinear_constant +ImageOpsCorrectnessTest::test_affine_transform_bilinear_mirror +ImageOpsCorrectnessTest::test_affine_transform_bilinear_nearest +ImageOpsCorrectnessTest::test_affine_transform_bilinear_reflect +ImageOpsCorrectnessTest::test_affine_transform_bilinear_wrap +LinalgOpsCorrectnessTest::test_cholesky_inverse_lower +LinalgOpsCorrectnessTest::test_cholesky_inverse_upper +LinalgOpsCorrectnessTest::test_eig +LinalgOpsCorrectnessTest::test_svd +LSTMTest::test_correctness0 +LSTMTest::test_correctness1 +LSTMTest::test_masking +LSTMTest::test_pass_initial_state +LSTMTest::test_statefulness +MathOpsCorrectnessTest::test_extract_sequences +MergingLayersTest::test_correctness_dynamic_dot_3d +MergingLayersTest::test_correctness_static_dot_3d +MuonTest::test_Newton_Schulz +NNOpsCorrectnessTest::test_conv_2d0 +NNOpsCorrectnessTest::test_conv_2d1 +NNOpsCorrectnessTest::test_conv_2d2 
+NNOpsCorrectnessTest::test_conv_2d3 +NNOpsCorrectnessTest::test_conv_2d4 +NNOpsCorrectnessTest::test_conv_2d5 +NNOpsCorrectnessTest::test_conv_3d0 +NNOpsCorrectnessTest::test_conv_3d1 +NNOpsCorrectnessTest::test_conv_3d10 +NNOpsCorrectnessTest::test_conv_3d11 +NNOpsCorrectnessTest::test_conv_3d2 +NNOpsCorrectnessTest::test_conv_3d3 +NNOpsCorrectnessTest::test_conv_3d4 +NNOpsCorrectnessTest::test_conv_3d5 +NNOpsCorrectnessTest::test_conv_3d6 +NNOpsCorrectnessTest::test_conv_3d7 +NNOpsCorrectnessTest::test_conv_3d8 +NNOpsCorrectnessTest::test_conv_3d9 +NNOpsCorrectnessTest::test_ctc_loss +NNOpsCorrectnessTest::test_depthwise_conv_2d0 +NNOpsCorrectnessTest::test_depthwise_conv_2d1 +NNOpsCorrectnessTest::test_depthwise_conv_2d10 +NNOpsCorrectnessTest::test_depthwise_conv_2d11 +NNOpsCorrectnessTest::test_depthwise_conv_2d2 +NNOpsCorrectnessTest::test_depthwise_conv_2d3 +NNOpsCorrectnessTest::test_depthwise_conv_2d4 +NNOpsCorrectnessTest::test_depthwise_conv_2d5 +NNOpsCorrectnessTest::test_depthwise_conv_2d6 +NNOpsCorrectnessTest::test_depthwise_conv_2d7 +NNOpsCorrectnessTest::test_depthwise_conv_2d8 +NNOpsCorrectnessTest::test_depthwise_conv_2d9 +NNOpsCorrectnessTest::test_separable_conv_2d0 +NNOpsCorrectnessTest::test_separable_conv_2d1 +NNOpsCorrectnessTest::test_separable_conv_2d2 +NNOpsCorrectnessTest::test_separable_conv_2d3 +NNOpsCorrectnessTest::test_separable_conv_2d4 +NNOpsCorrectnessTest::test_separable_conv_2d5 +NNOpsCorrectnessTest::test_separable_conv_2d6 +NNOpsCorrectnessTest::test_separable_conv_2d7 +NumpyOneInputOpsDynamicShapeTest::test_argmax_negative_zero +NumpyOneInputOpsDynamicShapeTest::test_argmin_negative_zero +NumpyTwoInputOpsCorrectnessTest::test_logspace +NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float32_false_false +NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank2_float64_false_false +NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float16_false_false 
+NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float32_false_false +NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank3_float64_false_false +NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float16_false_false +NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float32_false_false +NumpyTwoInputOpsCorrectnessTest::test_matmul_sparse_rank4_float64_false_false +RandomGaussianBlurTest::test_random_erasing_basic +RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_large_scale +RandomPerspectiveTest::test_random_perspective_bounding_boxes_with_small_scale +RandomZoomTest::test_random_zoom_out_correctness +RegularizersTest::test_orthogonal_regularizer +RNNTest::test_go_backwards +SeparableConvCorrectnessTest::test_separable_conv1d0 +SeparableConvCorrectnessTest::test_separable_conv1d1 +SeparableConvCorrectnessTest::test_separable_conv1d2 +SeparableConvCorrectnessTest::test_separable_conv2d0 +SeparableConvCorrectnessTest::test_separable_conv2d1 +SeparableConvCorrectnessTest::test_separable_conv2d2 +SimpleRNNTest::test_correctness +SimpleRNNTest::test_correctness +SimpleRNNTest::test_masking +SimpleRNNTest::test_masking +SimpleRNNTest::test_pass_initial_state +SimpleRNNTest::test_pass_initial_state +SimpleRNNTest::test_return_state +SimpleRNNTest::test_statefulness +SimpleRNNTest::test_statefulness +StackedRNNTest::test_correctness_single_state_stack +StackedRNNTest::test_correctness_two_states_stack +StackedRNNTest::test_statefullness_single_state_stack +StackedRNNTest::test_statefullness_two_states_stack +TestFitLRSchedulesFlow::test_fit_lr_correctness +TestJaxLayer::test_flax_layer_training_independent_bound_method +TestJaxLayer::test_flax_layer_training_rng_state_no_method +TestJaxLayer::test_flax_layer_training_rng_unbound_method +TestJaxLayer::test_flax_layer_training_rng_unbound_method_dtype_policy +TestJaxLayer::test_jax_layer_training_independent +TestJaxLayer::test_jax_layer_training_state 
+TestJaxLayer::test_jax_layer_training_state_dtype_policy +TestSpectrogram::test_spectrogram_error +TestTrainer::test_loss_weights +TestTrainer::test_nested_inputs +TestTrainer::test_on_batch_methods_eager +TestTrainer::test_on_batch_methods_graph_fn +TestTrainer::test_on_batch_methods_jit \ No newline at end of file diff --git a/requirements-common.txt b/requirements-common.txt index 2fecef1d5946..4d788fa21cbf 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -27,5 +27,4 @@ onnxruntime # onnxscript==0.3.2 breaks LSTM model export. onnxscript!=0.3.2 openvino -# for grain_dataset_adapter_test.py grain diff --git a/requirements-jax-tpu.txt b/requirements-jax-tpu.txt new file mode 100644 index 000000000000..4febbe8e8aab --- /dev/null +++ b/requirements-jax-tpu.txt @@ -0,0 +1,14 @@ +# Tensorflow cpu-only version (needed for testing). +tensorflow-cpu~=2.18.1 +tf2onnx + +# Torch cpu-only version (needed for testing). +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0 + +# Jax with TPU support. +--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html +jax[tpu] +flax + +-r requirements-common.txt diff --git a/requirements-tensorflow-tpu.txt b/requirements-tensorflow-tpu.txt new file mode 100644 index 000000000000..aaac402056bd --- /dev/null +++ b/requirements-tensorflow-tpu.txt @@ -0,0 +1,13 @@ +--find-links https://storage.googleapis.com/libtpu-tf-releases/index.html +tensorflow-tpu==2.19.1 + +tf2onnx + +# Torch cpu-only version (needed for testing). +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.6.0 + +# Jax cpu-only version (needed for testing). +jax + +-r requirements-common.txt