From b78809d48473985e8bdf11496416a01b6e8b2fa0 Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Thu, 30 Oct 2025 09:50:41 -0700 Subject: [PATCH 1/2] Use the JTE API to set stacked table stats to the maximum of the input table specs. This allows setting parameters like `max_ids_per_partition`, `max_unique_ids_per_partition`, and `suggested_coo_buffer_size_per_device` for stacked tables with auto-stacking. Although the heuristic may not be optimal, this at least provides a method for directly setting the values in the stacked tables, and is consistent with the default values if nothing is set. Uses the `jax_tpu_embedding` API for future-proofing. --- .../embedding/jax/distributed_embedding.py | 65 ++++++++++++------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index f27e109a..87a85089 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -1,6 +1,5 @@ """JAX implementation of the TPU embedding layer.""" -import dataclasses import math import typing from typing import Any, Mapping, Sequence, Union @@ -446,29 +445,51 @@ def sparsecore_build( table_specs = embedding.get_table_specs(feature_specs) table_stacks = jte_table_stacking.get_table_stacks(table_specs) - # Create new instances of StackTableSpec with updated values that are - # the maximum from stacked tables. - stacked_table_specs = embedding.get_stacked_table_specs(feature_specs) - stacked_table_specs = { - stack_name: dataclasses.replace( - stacked_table_spec, - max_ids_per_partition=max( - table.max_ids_per_partition - for table in table_stacks[stack_name] - ), - max_unique_ids_per_partition=max( - table.max_unique_ids_per_partition - for table in table_stacks[stack_name] - ), + # Update stacked table stats to max of values across involved tables. 
+ max_ids_per_partition = {} + max_unique_ids_per_partition = {} + required_buffer_size_per_device = {} + id_drop_counters = {} + for stack_name, stack in table_stacks.items(): + max_ids_per_partition[stack_name] = np.max( + np.asarray( + [s.max_ids_per_partition for s in stack], dtype=np.int32 + ) + ) + max_unique_ids_per_partition[stack_name] = np.max( + np.asarray( + [s.max_unique_ids_per_partition for s in stack], + dtype=np.int32, + ) ) - for stack_name, stacked_table_spec in stacked_table_specs.items() - } - # Rewrite the stacked_table_spec in all TableSpecs. - for stack_name, table_specs in table_stacks.items(): - stacked_table_spec = stacked_table_specs[stack_name] - for table_spec in table_specs: - table_spec.stacked_table_spec = stacked_table_spec + # Only set the suggested buffer size if set on any individual table. + suggested_buffer_sizes = np.asarray( + [ + s.suggested_coo_buffer_size_per_device + for s in stack + if s.suggested_coo_buffer_size_per_device is not None + ], + dtype=np.int32, + ) + if len(suggested_buffer_sizes) > 0: + required_buffer_size_per_device[stack_name] = np.max( + suggested_buffer_sizes + ) + + id_drop_counters[stack_name] = 0 + + aggregated_stats = embedding.SparseDenseMatmulInputStats( + max_ids_per_partition=max_ids_per_partition, + max_unique_ids_per_partition=max_unique_ids_per_partition, + required_buffer_size_per_sc=required_buffer_size_per_device, + id_drop_counters=id_drop_counters, + ) + embedding.update_preprocessing_parameters( + feature_specs, + aggregated_stats, + num_sc_per_device, + ) # Create variables for all stacked tables and slot variables. 
with sparsecore_distribution.scope(): From f89cd036436af23c5e01d5f542777b0fda63e219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=2E=20Antonio=20S=C3=A1nchez?= Date: Thu, 30 Oct 2025 12:17:29 -0700 Subject: [PATCH 2/2] Update keras_rs/src/layers/embedding/jax/distributed_embedding.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../embedding/jax/distributed_embedding.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 87a85089..c4e647bf 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -464,17 +464,14 @@ def sparsecore_build( ) # Only set the suggested buffer size if set on any individual table. - suggested_buffer_sizes = np.asarray( - [ - s.suggested_coo_buffer_size_per_device - for s in stack - if s.suggested_coo_buffer_size_per_device is not None - ], - dtype=np.int32, - ) - if len(suggested_buffer_sizes) > 0: + valid_buffer_sizes = [ + s.suggested_coo_buffer_size_per_device + for s in stack + if s.suggested_coo_buffer_size_per_device is not None + ] + if valid_buffer_sizes: required_buffer_size_per_device[stack_name] = np.max( - suggested_buffer_sizes + np.asarray(valid_buffer_sizes, dtype=np.int32) ) id_drop_counters[stack_name] = 0