diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 5088f571..26744155 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -90,7 +90,7 @@ jobs: run: python3 -c "import jax; print('JAX devices:', jax.devices())" - name: Test with pytest - run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py + run: pytest keras_rs/src/layers/embedding/distributed_embedding_test.py keras_rs/src/layers/embedding/${{ matrix.backend }} check_format: name: Check the code format diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py index 1dd3525d..c411c5cc 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py @@ -338,7 +338,7 @@ def test_call( # Trigger layer.build(...) to initialize tables. sample_ids, sample_weights = keras_test_utils.create_random_samples( - feature_configs, ragged=ragged, seed=0 + feature_configs, ragged=ragged, seed=0, max_ids_per_sample=10 ) inputs = layer.preprocess(sample_ids, sample_weights) _ = layer(inputs)