6 changes: 5 additions & 1 deletion keras/src/backend/__init__.py
@@ -37,24 +37,28 @@
if backend() == "tensorflow":
from keras.src.backend.tensorflow import * # noqa: F403
from keras.src.backend.tensorflow.core import Variable as BackendVariable

distributed_backend = None
elif backend() == "jax":
from keras.src.backend.jax import * # noqa: F403
from keras.src.backend.jax.core import Variable as BackendVariable

elif backend() == "torch":
from keras.src.backend.torch import * # noqa: F403
from keras.src.backend.torch.core import Variable as BackendVariable

distribution_lib = None
elif backend() == "numpy":
from keras.src.backend.numpy import * # noqa: F403
from keras.src.backend.numpy.core import Variable as BackendVariable

distribution_lib = None
distributed_backend = None
elif backend() == "openvino":
from keras.src.backend.openvino import * # noqa: F403
from keras.src.backend.openvino.core import Variable as BackendVariable

distribution_lib = None
distributed_backend = None
else:
raise ValueError(f"Unable to import backend : {backend()}")

1 change: 1 addition & 0 deletions keras/src/backend/jax/__init__.py
@@ -1,5 +1,6 @@
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax import core
from keras.src.backend.jax import distributed_backend
from keras.src.backend.jax import distribution_lib
from keras.src.backend.jax import image
from keras.src.backend.jax import linalg
2 changes: 2 additions & 0 deletions keras/src/backend/torch/__init__.py
@@ -16,6 +16,8 @@

from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.torch import core
from keras.src.backend.torch import distributed_backend
from keras.src.backend.torch import distribution_lib
from keras.src.backend.torch import image
from keras.src.backend.torch import linalg
from keras.src.backend.torch import math
142 changes: 142 additions & 0 deletions keras/src/backend/torch/distributed_backend.py
@@ -0,0 +1,142 @@
import os

import torch
import torch.distributed as dist


def get_device_info():
"""Retrieves information about the available PyTorch devices.

This function queries PyTorch to identify the type and number of
available computational devices (e.g., CPU, GPU).

Returns:
        dict: A dictionary containing the backend name (e.g. 'torch (CPU)'
        or 'torch (CUDA)'), a list of device string representations, and
        the total count of devices.
"""
if torch.cuda.is_available():
device_count = torch.cuda.device_count()
devices = [
f"cuda:{i} ({torch.cuda.get_device_name(i)})"
for i in range(device_count)
]
backend = "torch (CUDA)"
else:
device_count = 1
devices = ["cpu"]
backend = "torch (CPU)"

return {
"backend": backend,
"devices": devices,
"device_count": device_count,
}


def is_multi_device_capable():
"""Checks if more than one device is available for distributed computation.

Returns:
        bool: True if an initialized torch.distributed process group reports
        a world size greater than one, or if more than one CUDA device is
        visible locally; False otherwise.
"""
if dist.is_available() and dist.is_initialized():
return dist.get_world_size() > 1
elif torch.cuda.is_available():
return torch.cuda.device_count() > 1
return False


def setup_distributed_environment():
"""
A helper function to initialize the distributed process group.

This is a prerequisite for using the communication operations.
In a real application, this would be called at the start of the script.
It uses environment variables commonly set by launchers like torchrun.
"""
if dist.is_available() and not dist.is_initialized():
required_env_vars = ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"]
if not all(v in os.environ for v in required_env_vars):
return False

dist.init_process_group(backend="nccl")
return True
elif dist.is_initialized():
return True
else:
return False


def get_communication_ops():
"""Provides a dictionary of PyTorch collective communication operations.

    Note: The torch.distributed process group should be initialized before
    calling these functions; if it is not, they fall back to returning
    their input unchanged (single-device behavior).

Returns:
dict: A dictionary mapping operation names (e.g., 'all_reduce') to their
corresponding PyTorch implementation functions.
"""

def all_reduce(x, op="sum"):
"""Reduces a tensor across all devices in the process group.

This function performs a collective reduction operation
across all devices in the distributed group.

Args:
x (torch.Tensor): The input tensor on the local device.
op (str, optional): The reduction operation to perform. Supported
values are 'sum' and 'mean'. Defaults to 'sum'.

Returns:
torch.Tensor: The reduced tensor, which is identical across all
devices participating in the reduction.
"""
if not (dist.is_available() and dist.is_initialized()):
return x

if op == "sum":
reduce_op = dist.ReduceOp.SUM
elif op == "mean":
reduce_op = dist.ReduceOp.AVG
else:
raise ValueError(
f"Unsupported reduction operation: {op}. "
"Supported options are 'sum' and 'mean'."
)

result = x.clone()
dist.all_reduce(result, op=reduce_op)
return result

def all_gather(x, axis):
"""Gathers and concatenates tensors from all devices.

This function takes the local tensor `x` from each device and
concatenates them along the specified tensor `axis` to form a single,
larger tensor that is then replicated on all participating devices.

Args:
x (torch.Tensor): The input tensor shard on the local device.
axis (int): The tensor axis along which to concatenate the gathered
shards.

Returns:
torch.Tensor: The full, gathered tensor, which is identical across
all devices participating in the gather.
"""
if not (dist.is_available() and dist.is_initialized()):
return x

world_size = dist.get_world_size()
tensor_list = [torch.empty_like(x) for _ in range(world_size)]

dist.all_gather(tensor_list, x)
return torch.cat(tensor_list, dim=axis)

return {
"all_reduce": all_reduce,
"all_gather": all_gather,
}
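
For reference, a hypothetical end-to-end use of this module under `torchrun` might look as follows. The launcher flags, script name, and tensor shapes are assumptions for illustration; the NCCL group created by `setup_distributed_environment()` needs one CUDA device per rank.

```python
# Hypothetical usage sketch (not part of the PR), e.g. launched with:
#   torchrun --nproc_per_node=2 demo.py
# torchrun exports MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE, which
# setup_distributed_environment() checks before creating the NCCL group.
import os

import torch

from keras.src.backend.torch import distributed_backend

print(distributed_backend.get_device_info())
print("multi-device capable:", distributed_backend.is_multi_device_capable())

ops = distributed_backend.get_communication_ops()

if distributed_backend.setup_distributed_environment():
    # NCCL collectives expect CUDA tensors; place this rank's shard on its GPU.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    x = torch.ones(2, 2, device=f"cuda:{local_rank}")
    summed = ops["all_reduce"](x, op="sum")    # identical result on every rank
    gathered = ops["all_gather"](x, axis=0)    # shape: (2 * world_size, 2)
    print(summed.shape, gathered.shape)
else:
    # Without an initialized process group the ops return their input
    # unchanged, so the same script still runs on a single device.
    print(ops["all_reduce"](torch.ones(2, 2), op="mean"))
```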
51 changes: 51 additions & 0 deletions keras/src/backend/torch/distributed_backend_test.py
@@ -0,0 +1,51 @@
import pytest
import torch
import torch.distributed as dist

from keras.src import backend
from keras.src.backend import distributed_backend


@pytest.mark.skipif(
backend.backend() != "torch",
reason="Torch Backend specific test",
)
class TestPytorchDistributedFunctions:
"""Unit tests for the PyTorch distributed backend standalone functions."""

def test_get_device_info(self):
"""Test retrieving device information from the PyTorch backend."""
info = distributed_backend.get_device_info()
assert info["backend"] == "torch (CPU)"
assert isinstance(info["devices"], list)
assert isinstance(info["device_count"], int)
assert info["device_count"] > 0
assert len(info["devices"]) == info["device_count"]
if torch.cuda.is_available():
assert info["device_count"] == torch.cuda.device_count()
else:
assert info["device_count"] == 1
assert info["devices"] == ["cpu"]

def test_is_multi_device_capable(self):
"""Test the boolean check for multi-device capability."""
assert isinstance(distributed_backend.is_multi_device_capable(), bool)

def test_communication_ops_simulation_logic(self):
"""Test the simulated communication ops in a single-device context."""
comm_ops = distributed_backend.get_communication_ops()
        # With no initialized torch.distributed process group, the ops
        # return their input unchanged, so the effective world size is the
        # size of the process group (1 when it is not initialized), not the
        # local device count.
        world_size = (
            dist.get_world_size()
            if dist.is_available() and dist.is_initialized()
            else 1
        )

# Test all_reduce
x_reduce = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
reduced = comm_ops["all_reduce"](x_reduce, op="sum")
expected_reduce = (
x_reduce * float(world_size) if world_size > 1 else x_reduce
)
torch.testing.assert_close(reduced, expected_reduce)

# Test all_gather
x_gather = torch.tensor([[1.0, 2.0]])
gathered = comm_ops["all_gather"](x_gather, axis=0)
expected_gather = torch.cat([x_gather] * world_size, dim=0)
torch.testing.assert_close(gathered, expected_gather)
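
To exercise the collectives against a real process group (rather than the single-device fallback above), one could spawn CPU workers with a Gloo group. This sketch is illustrative only and assumes `KERAS_BACKEND=torch`; it initializes Gloo directly because `setup_distributed_environment()` hardcodes NCCL, which requires GPUs, and the port and world size are arbitrary choices.

```python
# Illustrative multi-process check using a Gloo (CPU) process group.
# Not part of the test suite above; run as a standalone script.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from keras.src.backend.torch import distributed_backend


def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # arbitrary free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    ops = distributed_backend.get_communication_ops()
    x = torch.full((2,), float(rank + 1))  # rank 0 -> [1, 1], rank 1 -> [2, 2]

    reduced = ops["all_reduce"](x, op="sum")
    torch.testing.assert_close(reduced, torch.full((2,), 3.0))

    gathered = ops["all_gather"](x, axis=0)
    assert gathered.shape == (2 * world_size,)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```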