Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions keras_rs/src/layers/embedding/jax/distributed_embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""JAX implementation of the TPU embedding layer."""

import dataclasses
import math
import typing
from typing import Any, Mapping, Sequence, Union
Expand Down Expand Up @@ -445,6 +446,30 @@ def sparsecore_build(
table_specs = embedding.get_table_specs(feature_specs)
table_stacks = jte_table_stacking.get_table_stacks(table_specs)

# Create new instances of StackedTableSpec whose `max_ids_per_partition`
# and `max_unique_ids_per_partition` are the maximum of those values across
# the tables in each stack.
stacked_table_specs = embedding.get_stacked_table_specs(feature_specs)
stacked_table_specs = {
stack_name: dataclasses.replace(
stacked_table_spec,
max_ids_per_partition=max(
table.max_ids_per_partition
for table in table_stacks[stack_name]
),
max_unique_ids_per_partition=max(
table.max_unique_ids_per_partition
for table in table_stacks[stack_name]
),
)
for stack_name, stacked_table_spec in stacked_table_specs.items()
}

# Rewrite the stacked_table_spec in all TableSpecs.
for stack_name, table_specs in table_stacks.items():
stacked_table_spec = stacked_table_specs[stack_name]
for table_spec in table_specs:
table_spec.stacked_table_spec = stacked_table_spec

# Create variables for all stacked tables and slot variables.
with sparsecore_distribution.scope():
self._table_and_slot_variables = {
Expand Down