|
1 | 1 | """JAX implementation of the TPU embedding layer.""" |
2 | 2 |
|
| 3 | +import collections |
| 4 | +import dataclasses |
3 | 5 | import math |
4 | 6 | import typing |
5 | 7 | from typing import Any, Mapping, Sequence, Union |
@@ -445,6 +447,30 @@ def sparsecore_build( |
445 | 447 | table_specs = embedding.get_table_specs(feature_specs) |
446 | 448 | table_stacks = jte_table_stacking.get_table_stacks(table_specs) |
447 | 449 |
|
| 450 | + # Create new instances of StackTableSpec with updated values that are |
| 451 | + # the maximum from stacked tables. |
| 452 | + stacked_table_specs = embedding.get_stacked_table_specs(feature_specs) |
| 453 | + stacked_table_specs = { |
| 454 | + stack_name: dataclasses.replace( |
| 455 | + stacked_table_spec, |
| 456 | + max_ids_per_partition=max( |
| 457 | + table.max_ids_per_partition |
| 458 | + for table in table_stacks[stack_name] |
| 459 | + ), |
| 460 | + max_unique_ids_per_partition=max( |
| 461 | + table.max_unique_ids_per_partition |
| 462 | + for table in table_stacks[stack_name] |
| 463 | + ), |
| 464 | + ) |
| 465 | + for stack_name, stacked_table_spec in stacked_table_specs.items() |
| 466 | + } |
| 467 | + |
| 468 | + # Rewrite the stacked_table_spec in all TableSpecs. |
| 469 | + for stack_name, table_specs in table_stacks.items(): |
| 470 | + stacked_table_spec = stacked_table_specs[stack_name] |
| 471 | + for table_spec in table_specs: |
| 472 | + table_spec.stacked_table_spec = stacked_table_spec |
| 473 | + |
448 | 474 | # Create variables for all stacked tables and slot variables. |
449 | 475 | with sparsecore_distribution.scope(): |
450 | 476 | self._table_and_slot_variables = { |
|
0 commit comments