keras/src/backend/jax/core.py (56 additions, 0 deletions)

@@ -1,5 +1,6 @@
import jax
import jax.experimental.sparse as jax_sparse
import jax.lax as lax
import jax.numpy as jnp
import ml_dtypes
import numpy as np
@@ -529,6 +530,61 @@ def remat(f):
return jax.checkpoint(f)


def all_reduce(x, op="sum", axis_name="model"):
"""
Performs an **all-reduce** operation across all replicas in the specified
distribution axis.

The all-reduce operation computes a reduction (like sum or mean)
of the input tensor `x` across all devices/replicas in the `axis_name`
group, and then broadcasts the result back to all participating devices.

Args:
x: The tensor to reduce.
op: The reduction operation to perform. Common options include "sum"
and "mean". Defaults to "sum".
axis_name: The name of the distribution axis (e.g., "model",
"data") over which to perform the reduction. Defaults to "model".

Returns:
The result of the all-reduce operation, with the same shape as the
input `x`.
"""
if op == "sum":
return lax.psum(x, axis_name=axis_name)
elif op == "mean":
return lax.pmean(x, axis_name=axis_name)
else:
raise ValueError(
f"Unsupported reduction operation: {op}. "
"Supported options are 'sum' and 'mean'."
)
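A minimal usage sketch (not part of the diff): `all_reduce` only communicates inside a parallel context that binds `axis_name`, such as `jax.pmap`. Everything below besides `all_reduce` itself is illustrative.

```python
import functools

import jax
import jax.numpy as jnp

from keras.src.backend.jax.core import all_reduce

num_devices = jax.local_device_count()

@functools.partial(jax.pmap, axis_name="model")
def reduce_step(x):
    # Inside pmap, "model" names the mapped axis, so psum/pmean can
    # communicate across the participating devices.
    return all_reduce(x, op="sum", axis_name="model")

# Device i holds the scalar value i; after the all-reduce every device
# holds 0 + 1 + ... + (num_devices - 1).
out = reduce_step(jnp.arange(float(num_devices)))
```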


def all_gather(x, axis, axis_name="model"):
"""
Performs an all-gather operation across all replicas in the specified
distribution axis.

The all-gather operation collects the input tensor `x` from all devices
in the `axis_name` group and concatenates them along the specified `axis`.
This is often used in tensor parallelism to combine parts of a tensor
distributed across devices.

Args:
x: The tensor to gather.
axis: The dimension along which to concatenate the gathered tensors.
axis_name: The name of the distribution axis (e.g., "model",
"data") over which to perform the gather.
Defaults to "model".

Returns:
The gathered tensor, which will have a larger size along `axis`
dimension.
"""
return lax.all_gather(x, axis_name=axis_name, axis=axis, tiled=True)
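A matching sketch for `all_gather` (again illustrative, not part of the diff). Because the call uses `tiled=True`, per-device shards are concatenated along `axis` rather than stacked under a new leading dimension:

```python
import functools

import jax
import jax.numpy as jnp

from keras.src.backend.jax.core import all_gather

num_devices = jax.local_device_count()

@functools.partial(jax.pmap, axis_name="model")
def gather_step(x):
    return all_gather(x, axis=0, axis_name="model")

# Device i holds rows [2 * i, 2 * i + 1]; after the gather, every device
# holds the full (num_devices * 2,) vector.
shards = jnp.arange(num_devices * 2.0).reshape(num_devices, 2)
out = gather_step(shards)  # shape: (num_devices, num_devices * 2)
```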


class name_scope(base_name_scope):
def __init__(self, name, **kwargs):
super().__init__(name, **kwargs)
keras/src/backend/jax/core_test.py (78 additions, 0 deletions)

@@ -1,3 +1,4 @@
import functools
import os

import jax
@@ -9,6 +10,8 @@
from keras.src import backend
from keras.src import testing
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax.core import all_gather
from keras.src.backend.jax.core import all_reduce

if is_nnx_enabled():
from flax import nnx
@@ -66,3 +69,78 @@ def test_keras_variable_nnx_split_merge_sync(self):
state = jax.tree.map(lambda x: x + 1, state)
variable2 = nnx.merge(graphdef, state)
self.assertEqual(variable2._value, variable2.value)


@pytest.mark.skipif(
backend.backend() != "jax",
reason="JAX backend specific test for collective operations.",
)
@pytest.mark.skipif(
    jax.local_device_count("cpu") < 2,
    reason="Requires multiple local CPU devices for testing.",
)
class JaxCollectiveOpsTest(testing.TestCase):
def test_all_reduce_sum(self):
"""Tests the all_reduce operation with the 'sum' reduction."""
        # The pmap calls below run on CPU devices, so count those explicitly.
        num_devices = jax.local_device_count("cpu")
local_value = 10.0

local_inputs = jax.numpy.array([local_value] * num_devices)

@functools.partial(
jax.pmap, axis_name="all", devices=jax.devices("cpu")
)
def reduce_sum_fn(x):
return all_reduce(x, op="sum", axis_name="all")

result = reduce_sum_fn(local_inputs)
expected_sum = local_value * num_devices

self.assertTrue(np.allclose(result, expected_sum))
self.assertEqual(result.shape, (num_devices,))

def test_all_reduce_mean(self):
"""Tests the all_reduce operation with the 'mean' reduction."""
        num_devices = jax.local_device_count("cpu")
local_value = 10.0

local_inputs = jax.numpy.array([local_value] * num_devices)

@functools.partial(
jax.pmap, axis_name="all", devices=jax.devices("cpu")
)
def reduce_mean_fn(x):
return all_reduce(x, op="mean", axis_name="all")

result = reduce_mean_fn(local_inputs)
expected_mean = local_value

self.assertTrue(np.allclose(result, expected_mean))
self.assertEqual(result.shape, (num_devices,))

def test_all_gather(self):
"""Tests the all_gather operation."""
        num_devices = jax.local_device_count("cpu")
local_data = np.arange(5)

local_inputs = jax.numpy.stack(
[local_data + (i * 5) for i in range(num_devices)]
)

@functools.partial(
jax.pmap, axis_name="all", devices=jax.devices("cpu")
)
def gather_fn(x):
return all_gather(x, axis=0, axis_name="all")

result_array_on_devices = gather_fn(local_inputs)

expected_shape = (num_devices, num_devices * local_data.shape[0])
self.assertEqual(result_array_on_devices.shape, expected_shape)

expected_gathered_data = np.arange(num_devices * local_data.shape[0])

for i in range(num_devices):
self.assertTrue(
np.allclose(result_array_on_devices[i], expected_gathered_data)
)
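These tests only run when several local devices are present. On a single-CPU machine, one common way to exercise them is to ask XLA for virtual CPU devices before JAX is first imported; the snippet below is a local-setup suggestion, not part of this PR:

```python
# Must run before jax is imported anywhere in the process,
# e.g. at the top of conftest.py.
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
```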