|
1 | 1 | """JAX implementation of the TPU embedding layer.""" |
2 | 2 |
|
3 | | -import dataclasses |
4 | 3 | import math |
5 | 4 | import typing |
6 | 5 | from typing import Any, Mapping, Sequence, Union |
@@ -446,29 +445,48 @@ def sparsecore_build( |
446 | 445 | table_specs = embedding.get_table_specs(feature_specs) |
447 | 446 | table_stacks = jte_table_stacking.get_table_stacks(table_specs) |
448 | 447 |
|
449 | | - # Create new instances of StackTableSpec with updated values that are |
450 | | - # the maximum from stacked tables. |
451 | | - stacked_table_specs = embedding.get_stacked_table_specs(feature_specs) |
452 | | - stacked_table_specs = { |
453 | | - stack_name: dataclasses.replace( |
454 | | - stacked_table_spec, |
455 | | - max_ids_per_partition=max( |
456 | | - table.max_ids_per_partition |
457 | | - for table in table_stacks[stack_name] |
458 | | - ), |
459 | | - max_unique_ids_per_partition=max( |
460 | | - table.max_unique_ids_per_partition |
461 | | - for table in table_stacks[stack_name] |
462 | | - ), |
| 448 | + # Update stacked table stats to max of values across involved tables. |
| 449 | + max_ids_per_partition = {} |
| 450 | + max_unique_ids_per_partition = {} |
| 451 | + required_buffer_size_per_device = {} |
| 452 | + id_drop_counters = {} |
| 453 | + for stack_name, stack in table_stacks.items(): |
| 454 | + max_ids_per_partition[stack_name] = np.max( |
| 455 | + np.asarray( |
| 456 | + [s.max_ids_per_partition for s in stack], dtype=np.int32 |
| 457 | + ) |
| 458 | + ) |
| 459 | + max_unique_ids_per_partition[stack_name] = np.max( |
| 460 | + np.asarray( |
| 461 | + [s.max_unique_ids_per_partition for s in stack], |
| 462 | + dtype=np.int32, |
| 463 | + ) |
463 | 464 | ) |
464 | | - for stack_name, stacked_table_spec in stacked_table_specs.items() |
465 | | - } |
466 | 465 |
|
467 | | - # Rewrite the stacked_table_spec in all TableSpecs. |
468 | | - for stack_name, table_specs in table_stacks.items(): |
469 | | - stacked_table_spec = stacked_table_specs[stack_name] |
470 | | - for table_spec in table_specs: |
471 | | - table_spec.stacked_table_spec = stacked_table_spec |
| 466 | + # Only set the suggested buffer size if set on any individual table. |
| 467 | + valid_buffer_sizes = [ |
| 468 | + s.suggested_coo_buffer_size_per_device |
| 469 | + for s in stack |
| 470 | + if s.suggested_coo_buffer_size_per_device is not None |
| 471 | + ] |
| 472 | + if valid_buffer_sizes: |
| 473 | + required_buffer_size_per_device[stack_name] = np.max( |
| 474 | + np.asarray(valid_buffer_sizes, dtype=np.int32) |
| 475 | + ) |
| 476 | + |
| 477 | + id_drop_counters[stack_name] = 0 |
| 478 | + |
| 479 | + aggregated_stats = embedding.SparseDenseMatmulInputStats( |
| 480 | + max_ids_per_partition=max_ids_per_partition, |
| 481 | + max_unique_ids_per_partition=max_unique_ids_per_partition, |
| 482 | + required_buffer_size_per_sc=required_buffer_size_per_device, |
| 483 | + id_drop_counters=id_drop_counters, |
| 484 | + ) |
| 485 | + embedding.update_preprocessing_parameters( |
| 486 | + feature_specs, |
| 487 | + aggregated_stats, |
| 488 | + num_sc_per_device, |
| 489 | + ) |
472 | 490 |
|
473 | 491 | # Create variables for all stacked tables and slot variables. |
474 | 492 | with sparsecore_distribution.scope(): |
|
0 commit comments