Skip to content

Commit 5d7f18a

Browse files
Use the JTE API to set stacked table stats... (#169)
* Use the JTE API to set stacked table stats to the maximum of the input table specs. This allows setting parameters like `max_ids_per_partition`, `max_unique_ids_per_partition`, and `suggested_coo_buffer_size` for stacked tables with auto-stacking. Although the heuristic may not be optimal, this at least provides a method for directly setting the values in the stacked tables, and is consistent with the default values if nothing is set. Uses the `jax_tpu_embedding` API for future-proofing.

* Update keras_rs/src/layers/embedding/jax/distributed_embedding.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 04fb241 commit 5d7f18a

File tree

1 file changed

+40
-22
lines changed

1 file changed

+40
-22
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 40 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
"""JAX implementation of the TPU embedding layer."""
22

3-
import dataclasses
43
import math
54
import typing
65
from typing import Any, Mapping, Sequence, Union
@@ -446,29 +445,48 @@ def sparsecore_build(
446445
table_specs = embedding.get_table_specs(feature_specs)
447446
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
448447

449-
# Create new instances of StackTableSpec with updated values that are
450-
# the maximum from stacked tables.
451-
stacked_table_specs = embedding.get_stacked_table_specs(feature_specs)
452-
stacked_table_specs = {
453-
stack_name: dataclasses.replace(
454-
stacked_table_spec,
455-
max_ids_per_partition=max(
456-
table.max_ids_per_partition
457-
for table in table_stacks[stack_name]
458-
),
459-
max_unique_ids_per_partition=max(
460-
table.max_unique_ids_per_partition
461-
for table in table_stacks[stack_name]
462-
),
448+
# Update stacked table stats to max of values across involved tables.
449+
max_ids_per_partition = {}
450+
max_unique_ids_per_partition = {}
451+
required_buffer_size_per_device = {}
452+
id_drop_counters = {}
453+
for stack_name, stack in table_stacks.items():
454+
max_ids_per_partition[stack_name] = np.max(
455+
np.asarray(
456+
[s.max_ids_per_partition for s in stack], dtype=np.int32
457+
)
458+
)
459+
max_unique_ids_per_partition[stack_name] = np.max(
460+
np.asarray(
461+
[s.max_unique_ids_per_partition for s in stack],
462+
dtype=np.int32,
463+
)
463464
)
464-
for stack_name, stacked_table_spec in stacked_table_specs.items()
465-
}
466465

467-
# Rewrite the stacked_table_spec in all TableSpecs.
468-
for stack_name, table_specs in table_stacks.items():
469-
stacked_table_spec = stacked_table_specs[stack_name]
470-
for table_spec in table_specs:
471-
table_spec.stacked_table_spec = stacked_table_spec
466+
# Only set the suggested buffer size if set on any individual table.
467+
valid_buffer_sizes = [
468+
s.suggested_coo_buffer_size_per_device
469+
for s in stack
470+
if s.suggested_coo_buffer_size_per_device is not None
471+
]
472+
if valid_buffer_sizes:
473+
required_buffer_size_per_device[stack_name] = np.max(
474+
np.asarray(valid_buffer_sizes, dtype=np.int32)
475+
)
476+
477+
id_drop_counters[stack_name] = 0
478+
479+
aggregated_stats = embedding.SparseDenseMatmulInputStats(
480+
max_ids_per_partition=max_ids_per_partition,
481+
max_unique_ids_per_partition=max_unique_ids_per_partition,
482+
required_buffer_size_per_sc=required_buffer_size_per_device,
483+
id_drop_counters=id_drop_counters,
484+
)
485+
embedding.update_preprocessing_parameters(
486+
feature_specs,
487+
aggregated_stats,
488+
num_sc_per_device,
489+
)
472490

473491
# Create variables for all stacked tables and slot variables.
474492
with sparsecore_distribution.scope():

0 commit comments

Comments
 (0)