99import numpy as np
1010from jax import numpy as jnp
1111from jax .experimental import layout as jax_layout
12+ from jax .experimental import multihost_utils
1213from jax_tpu_embedding .sparsecore .lib .nn import embedding
1314from jax_tpu_embedding .sparsecore .lib .nn import embedding_spec
1415from jax_tpu_embedding .sparsecore .lib .nn import (
@@ -600,31 +601,26 @@ def _sparsecore_preprocess(
600601 # underlying stacked tables specs in the feature specs.
601602
602603 # Aggregate stats across all processes/devices via pmax.
603- num_local_cpu_devices = jax .local_device_count ("cpu" )
604-
605- def pmax_aggregate (x : Any ) -> Any :
606- if not hasattr (x , "ndim" ):
607- x = np .array (x )
608- tiled_x = np .tile (x , (num_local_cpu_devices , * ([1 ] * x .ndim )))
609- return jax .pmap (
610- lambda y : jax .lax .pmax (y , "all_cpus" ), # type: ignore[no-untyped-call]
611- axis_name = "all_cpus" ,
612- backend = "cpu" ,
613- )(tiled_x )[0 ]
614-
615- full_stats = jax .tree .map (pmax_aggregate , stats )
604+ all_stats = multihost_utils .process_allgather (stats )
605+ aggregated_stats = jax .tree .map (
606+ lambda x : jnp .max (x , axis = 0 ), all_stats
607+ )
616608
617609 # Check if stats changed enough to warrant action.
618610 stacked_table_specs = embedding .get_stacked_table_specs (
619611 self ._config .feature_specs
620612 )
621613 changed = any (
622- np .max (full_stats .max_ids_per_partition [stack_name ])
614+ np .max (aggregated_stats .max_ids_per_partition [stack_name ])
623615 > spec .max_ids_per_partition
624- or np .max (full_stats .max_unique_ids_per_partition [stack_name ])
616+ or np .max (
617+ aggregated_stats .max_unique_ids_per_partition [stack_name ]
618+ )
625619 > spec .max_unique_ids_per_partition
626620 or (
627- np .max (full_stats .required_buffer_size_per_sc [stack_name ])
621+ np .max (
622+ aggregated_stats .required_buffer_size_per_sc [stack_name ]
623+ )
628624 * num_sc_per_device
629625 )
630626 > (spec .suggested_coo_buffer_size_per_device or 0 )
@@ -634,7 +630,9 @@ def pmax_aggregate(x: Any) -> Any:
634630 # Update configuration and repeat preprocessing if stats changed.
635631 if changed :
636632 embedding .update_preprocessing_parameters (
637- self ._config .feature_specs , full_stats , num_sc_per_device
633+ self ._config .feature_specs ,
634+ aggregated_stats ,
635+ num_sc_per_device ,
638636 )
639637
640638 # Re-execute preprocessing with consistent input statistics.
0 commit comments