Skip to content

Commit 6887dcb

Browse files
authored
Fix CPU stats aggregation for JAX multihost. (#164)
Replaces custom pmap+pmax with the special-purpose multihost_utils.process_allgather. Tested in a pseudo multihost (multiprocess) test.
1 parent 9f1557c commit 6887dcb

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
from jax import numpy as jnp
1111
from jax.experimental import layout as jax_layout
12+
from jax.experimental import multihost_utils
1213
from jax_tpu_embedding.sparsecore.lib.nn import embedding
1314
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
1415
from jax_tpu_embedding.sparsecore.lib.nn import (
@@ -600,31 +601,26 @@ def _sparsecore_preprocess(
600601
# underlying stacked tables specs in the feature specs.
601602

602603
# Aggregate stats across all processes/devices via pmax.
603-
num_local_cpu_devices = jax.local_device_count("cpu")
604-
605-
def pmax_aggregate(x: Any) -> Any:
606-
if not hasattr(x, "ndim"):
607-
x = np.array(x)
608-
tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
609-
return jax.pmap(
610-
lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call]
611-
axis_name="all_cpus",
612-
backend="cpu",
613-
)(tiled_x)[0]
614-
615-
full_stats = jax.tree.map(pmax_aggregate, stats)
604+
all_stats = multihost_utils.process_allgather(stats)
605+
aggregated_stats = jax.tree.map(
606+
lambda x: jnp.max(x, axis=0), all_stats
607+
)
616608

617609
# Check if stats changed enough to warrant action.
618610
stacked_table_specs = embedding.get_stacked_table_specs(
619611
self._config.feature_specs
620612
)
621613
changed = any(
622-
np.max(full_stats.max_ids_per_partition[stack_name])
614+
np.max(aggregated_stats.max_ids_per_partition[stack_name])
623615
> spec.max_ids_per_partition
624-
or np.max(full_stats.max_unique_ids_per_partition[stack_name])
616+
or np.max(
617+
aggregated_stats.max_unique_ids_per_partition[stack_name]
618+
)
625619
> spec.max_unique_ids_per_partition
626620
or (
627-
np.max(full_stats.required_buffer_size_per_sc[stack_name])
621+
np.max(
622+
aggregated_stats.required_buffer_size_per_sc[stack_name]
623+
)
628624
* num_sc_per_device
629625
)
630626
> (spec.suggested_coo_buffer_size_per_device or 0)
@@ -634,7 +630,9 @@ def pmax_aggregate(x: Any) -> Any:
634630
# Update configuration and repeat preprocessing if stats changed.
635631
if changed:
636632
embedding.update_preprocessing_parameters(
637-
self._config.feature_specs, full_stats, num_sc_per_device
633+
self._config.feature_specs,
634+
aggregated_stats,
635+
num_sc_per_device,
638636
)
639637

640638
# Re-execute preprocessing with consistent input statistics.

0 commit comments

Comments
 (0)