Skip to content

Commit b27a31d

Browse files
committed
Add heuristic for max ids of stacked tables.
Propagate the `max_ids_per_partition` and `max_unique_ids_per_partition` from `TableSpec`s to `StackedTableSpec`s by taking the max from the stacked tables.
1 parent 5713afe commit b27a31d

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""JAX implementation of the TPU embedding layer."""
22

3+
import collections
4+
import dataclasses
35
import math
46
import typing
57
from typing import Any, Mapping, Sequence, Union
@@ -445,6 +447,30 @@ def sparsecore_build(
445447
table_specs = embedding.get_table_specs(feature_specs)
446448
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
447449

450+
# Create new instances of StackedTableSpec with updated values that are
451+
# the maximum from stacked tables.
452+
stacked_table_specs = embedding.get_stacked_table_specs(feature_specs)
453+
stacked_table_specs = {
454+
stack_name: dataclasses.replace(
455+
stacked_table_spec,
456+
max_ids_per_partition=max(
457+
table.max_ids_per_partition
458+
for table in table_stacks[stack_name]
459+
),
460+
max_unique_ids_per_partition=max(
461+
table.max_unique_ids_per_partition
462+
for table in table_stacks[stack_name]
463+
),
464+
)
465+
for stack_name, stacked_table_spec in stacked_table_specs.items()
466+
}
467+
468+
# Rewrite the stacked_table_spec in all TableSpecs.
469+
for stack_name, table_specs in table_stacks.items():
470+
stacked_table_spec = stacked_table_specs[stack_name]
471+
for table_spec in table_specs:
472+
table_spec.stacked_table_spec = stacked_table_spec
473+
448474
# Create variables for all stacked tables and slot variables.
449475
with sparsecore_distribution.scope():
450476
self._table_and_slot_variables = {

0 commit comments

Comments
 (0)