From d002f0a2765c4a1a781b3b1f51717423cd127d28 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Wed, 15 Oct 2025 14:29:26 -0700 Subject: [PATCH] Add backend specific TPU tests for DistributedEmbedding. Under `keras_rs/src/layers/embedding`. --- .github/workflows/actions.yml | 2 +- keras_rs/src/layers/embedding/jax/distributed_embedding_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 5088f571..26744155 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -90,7 +90,7 @@ jobs: run: python3 -c "import jax; print('JAX devices:', jax.devices())" - name: Test with pytest - run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py + run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py keras_rs/src/layers/embedding/${{ matrix.backend }} check_format: name: Check the code format diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py index 1dd3525d..c411c5cc 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py @@ -338,7 +338,7 @@ def test_call( # Trigger layer.build(...) to initialize tables. sample_ids, sample_weights = keras_test_utils.create_random_samples( - feature_configs, ragged=ragged, seed=0 + feature_configs, ragged=ragged, seed=0, max_ids_per_sample=10 ) inputs = layer.preprocess(sample_ids, sample_weights) _ = layer(inputs)