diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 15f1af2145d5..ecae3a6c7631 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -37,24 +37,28 @@ if backend() == "tensorflow": from keras.src.backend.tensorflow import * # noqa: F403 from keras.src.backend.tensorflow.core import Variable as BackendVariable + + distributed_backend = None elif backend() == "jax": from keras.src.backend.jax import * # noqa: F403 from keras.src.backend.jax.core import Variable as BackendVariable + elif backend() == "torch": from keras.src.backend.torch import * # noqa: F403 from keras.src.backend.torch.core import Variable as BackendVariable - distribution_lib = None elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 from keras.src.backend.numpy.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None elif backend() == "openvino": from keras.src.backend.openvino import * # noqa: F403 from keras.src.backend.openvino.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None else: raise ValueError(f"Unable to import backend : {backend()}") diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 89ac0fa71c8c..0a275fb70cf1 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -1,5 +1,6 @@ from keras.src.backend.config import is_nnx_enabled from keras.src.backend.jax import core +from keras.src.backend.jax import distributed_backend from keras.src.backend.jax import distribution_lib from keras.src.backend.jax import image from keras.src.backend.jax import linalg diff --git a/keras/src/backend/torch/__init__.py b/keras/src/backend/torch/__init__.py index 371a62cd0f52..a7d0405a5567 100644 --- a/keras/src/backend/torch/__init__.py +++ b/keras/src/backend/torch/__init__.py @@ -16,6 +16,8 @@ from keras.src.backend.common.name_scope import name_scope from keras.src.backend.torch import core +from keras.src.backend.torch import distributed_backend +from keras.src.backend.torch import distribution_lib from keras.src.backend.torch import image from keras.src.backend.torch import linalg from keras.src.backend.torch import math diff --git a/keras/src/backend/torch/distributed_backend.py b/keras/src/backend/torch/distributed_backend.py new file mode 100644 index 000000000000..af00b07f87fe --- /dev/null +++ b/keras/src/backend/torch/distributed_backend.py @@ -0,0 +1,142 @@ +import os + +import torch +import torch.distributed as dist + + +def get_device_info(): + """Retrieves information about the available PyTorch devices. + + This function queries PyTorch to identify the type and number of + available computational devices (e.g., CPU, GPU). + + Returns: + dict: A dictionary containing the backend name ('torch'), a list of + device string representations, and the total count of devices. + """ + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + devices = [ + f"cuda:{i} ({torch.cuda.get_device_name(i)})" + for i in range(device_count) + ] + backend = "torch (CUDA)" + else: + device_count = 1 + devices = ["cpu"] + backend = "torch (CPU)" + + return { + "backend": backend, + "devices": devices, + "device_count": device_count, + } + + +def is_multi_device_capable(): + """Checks if more than one device is available for distributed computation. + + Returns: + bool: True if the PyTorch distributed environment is initialized and + has a world size greater than one, False otherwise. 
+ """ + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() > 1 + elif torch.cuda.is_available(): + return torch.cuda.device_count() > 1 + return False + + +def setup_distributed_environment(): + """ + A helper function to initialize the distributed process group. + + This is a prerequisite for using the communication operations. + In a real application, this would be called at the start of the script. + It uses environment variables commonly set by launchers like torchrun. + """ + if dist.is_available() and not dist.is_initialized(): + required_env_vars = ["MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE"] + if not all(v in os.environ for v in required_env_vars): + return False + + dist.init_process_group(backend="nccl") + return True + elif dist.is_initialized(): + return True + else: + return False + + +def get_communication_ops(): + """Provides a dictionary of PyTorch collective communication operations. + + Note: The torch.distributed process group must be initialized before + calling these functions. + + Returns: + dict: A dictionary mapping operation names (e.g., 'all_reduce') to their + corresponding PyTorch implementation functions. + """ + + def all_reduce(x, op="sum"): + """Reduces a tensor across all devices in the process group. + + This function performs a collective reduction operation + across all devices in the distributed group. + + Args: + x (torch.Tensor): The input tensor on the local device. + op (str, optional): The reduction operation to perform. Supported + values are 'sum' and 'mean'. Defaults to 'sum'. + + Returns: + torch.Tensor: The reduced tensor, which is identical across all + devices participating in the reduction. + """ + if not (dist.is_available() and dist.is_initialized()): + return x + + if op == "sum": + reduce_op = dist.ReduceOp.SUM + elif op == "mean": + reduce_op = dist.ReduceOp.AVG + else: + raise ValueError( + f"Unsupported reduction operation: {op}. " + "Supported options are 'sum' and 'mean'." + ) + + result = x.clone() + dist.all_reduce(result, op=reduce_op) + return result + + def all_gather(x, axis): + """Gathers and concatenates tensors from all devices. + + This function takes the local tensor `x` from each device and + concatenates them along the specified tensor `axis` to form a single, + larger tensor that is then replicated on all participating devices. + + Args: + x (torch.Tensor): The input tensor shard on the local device. + axis (int): The tensor axis along which to concatenate the gathered + shards. + + Returns: + torch.Tensor: The full, gathered tensor, which is identical across + all devices participating in the gather. 
+ """ + if not (dist.is_available() and dist.is_initialized()): + return x + + world_size = dist.get_world_size() + tensor_list = [torch.empty_like(x) for _ in range(world_size)] + + dist.all_gather(tensor_list, x) + return torch.cat(tensor_list, dim=axis) + + return { + "all_reduce": all_reduce, + "all_gather": all_gather, + } diff --git a/keras/src/backend/torch/distributed_backend_test.py b/keras/src/backend/torch/distributed_backend_test.py new file mode 100644 index 000000000000..70ba5caab1ae --- /dev/null +++ b/keras/src/backend/torch/distributed_backend_test.py @@ -0,0 +1,51 @@ +import pytest +import torch + +from keras.src import backend +from keras.src.backend import distributed_backend + + +@pytest.mark.skipif( + backend.backend() != "torch", + reason="Torch Backend specific test", +) +class TestPytorchDistributedFunctions: + """Unit tests for the PyTorch distributed backend standalone functions.""" + + def test_get_device_info(self): + """Test retrieving device information from the PyTorch backend.""" + info = distributed_backend.get_device_info() + assert info["backend"] == "torch (CPU)" + assert isinstance(info["devices"], list) + assert isinstance(info["device_count"], int) + assert info["device_count"] > 0 + assert len(info["devices"]) == info["device_count"] + if torch.cuda.is_available(): + assert info["device_count"] == torch.cuda.device_count() + else: + assert info["device_count"] == 1 + assert info["devices"] == ["cpu"] + + def test_is_multi_device_capable(self): + """Test the boolean check for multi-device capability.""" + assert isinstance(distributed_backend.is_multi_device_capable(), bool) + + def test_communication_ops_simulation_logic(self): + """Test the simulated communication ops in a single-device context.""" + comm_ops = distributed_backend.get_communication_ops() + device_info = distributed_backend.get_device_info() + world_size = device_info.get("device_count", 1) + + # Test all_reduce + x_reduce = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + reduced = comm_ops["all_reduce"](x_reduce, op="sum") + expected_reduce = ( + x_reduce * float(world_size) if world_size > 1 else x_reduce + ) + torch.testing.assert_close(reduced, expected_reduce) + + # Test all_gather + x_gather = torch.tensor([[1.0, 2.0]]) + gathered = comm_ops["all_gather"](x_gather, axis=0) + expected_gather = torch.cat([x_gather] * world_size, dim=0) + torch.testing.assert_close(gathered, expected_gather) diff --git a/keras/src/backend/torch/distribution_lib.py b/keras/src/backend/torch/distribution_lib.py new file mode 100644 index 000000000000..0d8c18de4bf7 --- /dev/null +++ b/keras/src/backend/torch/distribution_lib.py @@ -0,0 +1,413 @@ +"""Utilities for distribution strategy with Torch backend. + +This file contains the core Torch distribution primitives from Keras, +along with higher-level device management and auto-configuration utilities. +This version does not use try-except blocks for error handling. +""" + +import logging +import os +from typing import Dict +from typing import List +from typing import Optional + +import numpy as np +import torch +import torch.distributed as dist + +from keras.src.backend.common import global_state +from keras.src.random import seed_generator +from keras.src.utils import rng_utils + +logger = logging.getLogger(__name__) + + +def list_devices(device_type=None): + """Return all the available devices based on the device type. + + Note that this should return the global devices in a distributed setting. + + Args: + device_type: string of `"cpu"`, `"gpu"`. 
Defaults to `"gpu"` if + available when device_type is not provided. Otherwise will return + the `"cpu"` devices. `"tpu"` is not supported by the default + torch backend. + + Return: + List of devices that are available for distribute computation. + """ + if device_type: + device_type = device_type.lower() + else: + device_type = "cuda" if torch.cuda.is_available() else "cpu" + + if device_type in ("gpu", "cuda"): + if not torch.cuda.is_available(): + return [] + return [f"cuda:{i}" for i in range(torch.cuda.device_count())] + elif device_type == "cpu": + return ["cpu:0"] + elif device_type == "tpu": + logger.warning( + "TPU device type is not supported by the default " + "PyTorch backend. Use the `torch_xla` package." + ) + return [] + raise ValueError(f"Unknown device type: {device_type}") + + +def get_device_info(device_id: str) -> Dict[str, any]: + """ + Get detailed information about a specific device. + + Args: + device_id: Device identifier (e.g., 'cuda:0', 'cpu:0') + + Returns: + Dictionary containing device information + """ + device_info = { + "id": device_id, + "type": None, + "index": None, + "memory": None, + "capabilities": None, + } + + device_type, device_index = device_id.split(":") + device_type_map = {"cuda": "GPU", "cpu": "CPU"} + device_info["type"] = device_type_map.get(device_type, device_type.upper()) + device_info["index"] = int(device_index) + + return device_info + + +def get_best_devices(count: int = 1) -> List[str]: + """ + Get the best available devices for tensor parallelism. + + Args: + count: Number of devices needed + + Returns: + List of best device identifiers + """ + all_devices = list_devices("cuda") + if not all_devices: + all_devices = list_devices("cpu") + + if count <= 0: + return [] + + if count > len(all_devices): + logger.warning( + f"Requested {count} devices but only {len(all_devices)} available" + ) + count = len(all_devices) + + return all_devices[:count] + + +def get_device_backend(device_type: str) -> str: + """ + Get the recommended backend for a device type. + + Args: + device_type: Device type ('tpu', 'gpu', 'cpu') + + Returns: + Recommended backend name + """ + backend_mapping = {"gpu": "torch", "cuda": "torch", "cpu": "torch"} + + return backend_mapping.get(device_type.lower(), "torch") + + +def validate_device_placement(device_id: str) -> bool: + """ + Validate if a device can be used for tensor operations. + + Args: + device_id: Device identifier + + Returns: + True if device is valid and available + """ + if ":" not in device_id: + return False + + device_type = device_id.split(":")[0] + known_device_types = ("cpu", "gpu", "cuda", "tpu") + if device_type not in known_device_types: + return False + + all_devices = list_devices(device_type) + return device_id in all_devices + + +def get_device_memory_info(device_id: str) -> Optional[Dict[str, any]]: + """ + Get memory information for a device (if available). + + Args: + device_id: Device identifier + + Returns: + Memory information dictionary or None if not available + """ + if device_id.startswith("cuda:"): + return { + "type": "GPU", + "index": int(device_id.split(":")[1]), + "memory": "Available", + } + elif device_id.startswith("cpu:"): + return { + "type": "CPU", + "index": int(device_id.split(":")[1]), + "memory": "System RAM", + } + + return None + + +def auto_configure_tensor_parallel( + world_size: int = None, backend: str = None +) -> Dict[str, any]: + """ + Automatically configure tensor parallelism with the best available devices. 
+ + Args: + world_size: Number of devices to use (if None, uses all available GPUs) + backend: Backend to use (if None, will be set to 'torch') + + Returns: + Configuration dictionary with devices, backend, and other settings + """ + all_devices = list_devices() + + if not all_devices: + raise RuntimeError("No devices available for tensor parallelism") + + if world_size is None: + world_size = len(all_devices) + else: + world_size = min(world_size, len(all_devices)) + + selected_devices = all_devices[:world_size] + + recommended_backend = "torch" + + config = { + "devices": selected_devices, + "world_size": world_size, + "backend": recommended_backend, + } + + logger.info(f"Auto-configured tensor parallelism: {config}") + return config + + +def distribute_variable(value, layout): + """Create a distributed variable for PyTorch. + + This function creates a `torch.Tensor` distributed according to the given + layout. In PyTorch, variables and tensors are unified in the `Tensor` class. + + Args: + value: The initial value of the variable as a `torch.Tensor`. + layout: `TensorLayout` for the created variable, or a PyTorch-supported + layout instance (e.g., a list of `Placement` types). + + Returns: + `torch.Tensor` which is the distributed variable. + """ + return distribute_tensor(value, layout) + + +def distribute_tensor(tensor, layout): + """Distribute the tensor based on the layout. + + Args: + tensor: `torch.Tensor` that needs to be distributed. + layout: `TensorLayout` for the created variable, or a PyTorch-supported + layout instance (e.g., a list of `Placement` types). + + Returns: + Distributed `torch.Tensor`. + """ + # Avoid circular imports. + from keras.src.distribution import TensorLayout + + if isinstance(layout, TensorLayout): + placements = layout.backend_layout + device_mesh = layout.device_mesh.backend_mesh + else: + raise ValueError( + "Directly passing backend layout is not yet supported for torch. " + "Please provide a `keras.distribution.TensorLayout` instance." + ) + + return dist.dtensor.distribute_tensor( + tensor.to("cpu"), device_mesh, placements + ) + + +def distribute_data_input(per_process_batch, layout, batch_dim_name): + """Distribute the input data with the corresponding layout. + + Note that the input here is a local worker batch. PyTorch's `from_local` + is used to construct a global DTensor from these local shards. + + Args: + per_process_batch: `torch.Tensor` that is local shard for this process. + layout: `TensorLayout` for the distribution information. + + Returns: + A global batch distributed according to `layout`. + """ + from keras.src.distribution import TensorLayout + + if not isinstance(layout, TensorLayout): + raise ValueError( + "A `keras.distribution.TensorLayout` instance is required." + ) + + placements = layout.backend_layout + device_mesh = layout.device_mesh.backend_mesh + return dist.dtensor.from_local( + per_process_batch, device_mesh, placements, run_check=True + ) + + +def initialize_rng(): + """Initializes the global random number generator across processes. + + This is required for consistent initialization in multi-host settings. + It works by generating a seed on rank 0 and broadcasting it to all other + processes. 
+ """ + global_seed = rng_utils.get_random_seed() + if global_seed is None: + if not dist.is_initialized(): + seed = seed_generator.make_default_seed() + else: + if process_id() == 0: + seed = seed_generator.make_default_seed() + seed_tensor = torch.tensor( + seed, dtype=torch.int64, device="cpu" + ) + else: + seed_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + dist.broadcast(seed_tensor, src=0) + seed = seed_tensor.item() + global_seed = seed + rng_utils.set_random_seed(global_seed) + + global_seed_generator = global_state.get_global_attribute( + "global_seed_generator" + ) + if global_seed_generator is not None and global_seed_generator.seed is None: + global_state.set_global_attribute( + "global_seed_generator", + seed_generator.SeedGenerator( + seed=global_seed, + name=global_seed_generator.name, + backend=global_seed_generator.backend, + ), + ) + + +def initialize(job_addresses, num_processes, process_id): + """Initializes the distributed process group in PyTorch.""" + os.environ["RANK"] = str(process_id) + os.environ["WORLD_SIZE"] = str(num_processes) + + if "," in job_addresses: + master_addr = job_addresses.split(",")[0] + else: + master_addr = job_addresses + + if ":" not in master_addr: + raise ValueError( + "Invalid `job_addresses`. Expected format `hostname:port`, " + f"but got {master_addr}" + ) + + master_host, master_port = master_addr.split(":") + os.environ["MASTER_ADDR"] = master_host + os.environ["MASTER_PORT"] = master_port + + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend) + + initialize_rng() + + +def num_processes(): + """Return the number of processes for the current distribution setting.""" + if dist.is_initialized(): + return dist.get_world_size() + return 1 + + +def process_id(): + """Return the current process ID for the distribution setting.""" + if dist.is_initialized(): + return dist.get_rank() + return 0 + + +def _to_backend_device(device_name): + if isinstance(device_name, torch.device): + return device_name + return torch.device(device_name) + + +def _to_backend_mesh(device_mesh): + """Convert the DeviceMesh to Torch backend specific Mesh. + + Args: + device_mesh: DeviceMesh instance to convert. + + Returns: + A `torch.distributed.DeviceMesh` instance. + """ + mesh_shape = device_mesh.devices.shape + mesh_devices = np.array(device_mesh.devices.flatten()).reshape(mesh_shape) + return dist.DeviceMesh( + device_type="cuda" if torch.cuda.is_available() else "cpu", + mesh=mesh_devices, + ) + + +def _to_backend_layout(tensor_layout): + """Convert the TensorLayout to Torch backend specific placement. + + Args: + tensor_layout: TensorLayout instance to convert. + + Returns: + A list of `torch.distributed.placement_types.Placement` instances. + """ + if tensor_layout.device_mesh is None: + raise ValueError( + "Cannot create sharding when device mesh is not set " + "for TensorLayout." + ) + + mesh_axes = tensor_layout.device_mesh.axis_names + placements = [] + for axis in tensor_layout.axes: + if axis is None: + placements.append(dist.Replicate()) + else: + try: + mesh_dim = mesh_axes.index(axis) + placements.append(dist.Shard(mesh_dim)) + except ValueError: + raise ValueError( + f"Tensor axis `{axis}` is not found in the " + f"device mesh axes `{mesh_axes}`." 
+ ) from None + return placements diff --git a/keras/src/backend/torch/distribution_lib_test.py b/keras/src/backend/torch/distribution_lib_test.py new file mode 100644 index 000000000000..bf4c20403b51 --- /dev/null +++ b/keras/src/backend/torch/distribution_lib_test.py @@ -0,0 +1,156 @@ +import os + +import numpy as np +import pytest +import torch +import torch.distributed as dist + +from keras.src import backend +from keras.src.backend import distribution_lib +from keras.src.distribution import DeviceMesh +from keras.src.distribution import TensorLayout + + +def setup_torch_distributed(): + """ + A fixture to initialize the distributed process group if not already done. + This allows test file to be run directly with `pytest` for single-process + checks, while also working correctly when launched with `torchrun`. + """ + if not dist.is_available() or dist.is_initialized(): + return + + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", "29500") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + dist.init_process_group(backend="gloo") + + +@pytest.mark.skipif( + backend.backend() != "torch", + reason="Backend specific test", +) +class TestTorchDistributionLibLive: + """ + Tests for the Torch distribution library without using mocks. + These tests will reflect the capabilities of environment they are run in. + """ + + def test_device_listing_and_info(self): + """Tests device discovery functions against the runtime environment.""" + if torch.cuda.is_available(): + gpu_devices = distribution_lib.list_devices("gpu") + assert len(gpu_devices) == torch.cuda.device_count() + assert gpu_devices[0] == "cuda:0" + else: + assert distribution_lib.list_devices("gpu") == [] + + cpu_devices = distribution_lib.list_devices("cpu") + assert cpu_devices == ["cpu:0"] + + with pytest.raises(ValueError, match="Unknown device type"): + distribution_lib.list_devices("unsupported_device") + + def test_device_helpers(self): + """Tests validation, backend, and memory info functions.""" + device_str = "cpu:0" + if torch.cuda.is_available(): + device_str = "cuda:0" + + assert distribution_lib.validate_device_placement(device_str) is True + assert distribution_lib.validate_device_placement("invalid:0") is False + + assert distribution_lib.get_device_backend("cpu") == "torch" + assert distribution_lib.get_device_backend("gpu") == "torch" + + mem_info = distribution_lib.get_device_memory_info(device_str) + assert mem_info is not None + assert "type" in mem_info + assert mem_info["index"] == 0 + + def test_process_discovery(self): + """Tests process_id and num_processes in the live environment.""" + rank = distribution_lib.process_id() + world_size = distribution_lib.num_processes() + + if dist.is_initialized(): + assert rank == dist.get_rank() + assert world_size == dist.get_world_size() + else: + assert rank == 0 + assert world_size == 1 + + def test_backend_conversions(self): + """Tests the conversion of Keras objects to Torch backend objects.""" + world_size = distribution_lib.num_processes() + if world_size < 2: + pytest.skip( + "Skipping conversion tests in a single-process environment." 
+ ) + + devices = [f"cpu:{i}" for i in range(world_size)] + shape = (world_size,) + axis_names = ("data",) + keras_mesh = DeviceMesh(shape, axis_names, devices) + + torch_mesh = distribution_lib._to_backend_mesh(keras_mesh) + assert isinstance(torch_mesh, dist.DeviceMesh) + assert torch_mesh.mesh.shape == shape + + keras_layout = TensorLayout(axes=("data",), device_mesh=keras_mesh) + placements = distribution_lib._to_backend_layout(keras_layout) + assert isinstance(placements[0], dist.Shard) + + keras_layout_replicated = TensorLayout( + axes=(None,), device_mesh=keras_mesh + ) + placements_replicated = distribution_lib._to_backend_layout( + keras_layout_replicated + ) + assert isinstance(placements_replicated[0], dist.Replicate) + + def test_tensor_distribution(self): + """Tests the distribution of a tensor into a DTensor.""" + if not dist.is_initialized() or distribution_lib.num_processes() < 2: + pytest.skip( + "Tensor distribution test requires a multi-process environment." + ) + + world_size = distribution_lib.num_processes() + devices = np.arange(world_size) + keras_mesh = DeviceMesh((world_size,), ("batch",), devices) + keras_layout = TensorLayout(("batch", None), keras_mesh) + + local_tensor = torch.randn((10, 20)) + + dtensor = distribution_lib.distribute_tensor(local_tensor, keras_layout) + assert isinstance(dtensor, torch.distributed.dtensor.DTensor) + assert dtensor.device_mesh.mesh.shape == (world_size,) + assert isinstance(dtensor.placements[0], dist.Shard) + + dvariable = distribution_lib.distribute_variable( + local_tensor, keras_layout + ) + assert isinstance(dvariable, torch.distributed.dtensor.DTensor) + + def test_distribute_data_input(self): + """Tests the `from_local` logic for distributing input data.""" + if not dist.is_initialized() or distribution_lib.num_processes() < 2: + pytest.skip( + "Input distribution test requires a multi-process environment." + ) + + world_size = distribution_lib.num_processes() + devices = np.arange(world_size) + keras_mesh = DeviceMesh((world_size,), ("batch",), devices) + keras_layout = TensorLayout(("batch", None), keras_mesh) + + per_process_batch = torch.ones((8, 16)) + + global_batch = distribution_lib.distribute_data_input( + per_process_batch, keras_layout, batch_dim_name="batch" + ) + + assert isinstance(global_batch, torch.distributed.dtensor.DTensor) + assert global_batch.shape == (world_size * 8, 16) diff --git a/keras/src/distribution/tensor_parallel/parameter_sharding.py b/keras/src/distribution/tensor_parallel/parameter_sharding.py new file mode 100644 index 000000000000..f3282968a414 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/parameter_sharding.py @@ -0,0 +1,487 @@ +import re + +from keras.src.backend import distributed_backend + + +class ShardedWeight: + """A wrapper for a sharded Keras Variable to provide a consistent interface. + + This class wraps a tensor shard in a Keras Variable, making it compatible + with the Keras ecosystem. It exposes common variable properties like name, + shape, and trainable status. + """ + + def __init__(self, tensor_shard, name, trainable=True): + """Initializes the ShardedWeight. + + Args: + tensor_shard: The tensor piece (shard) to be managed by this weight. + name (str): The name for the underlying Keras Variable. + trainable (bool, optional): Whether the variable is trainable. + Defaults to True. 
+ """ + import keras + + self._variable = keras.Variable( + initializer=tensor_shard, trainable=trainable, name=name + ) + self.regularizer = None + + @property + def name(self): + """Returns the name of the underlying variable.""" + return self._variable.name + + @property + def trainable(self): + """Returns whether the variable is trainable.""" + return self._variable.trainable + + @property + def shape(self): + """Returns the shape of the variable.""" + return self._variable.shape + + @property + def dtype(self): + """Returns the dtype of the underlying variable.""" + return self._variable.dtype + + @property + def variable(self): + """Provides direct access to the underlying Keras Variable.""" + return self._variable + + @property + def value(self): + """Returns the value of the underlying variable.""" + return self._variable.value + + def numpy(self): + """Returns the value of the variable as a NumPy array.""" + return self._variable.numpy() + + def num_elements(self): + """Returns the total number of elements in the tensor.""" + import keras + + return keras.ops.size(self._variable) + + def __repr__(self): + """Returns a string representation of the ShardedWeight.""" + return ( + f"" + ) + + +class ParameterShardingStrategy: + """Implements parameter-level sharding for a Keras model. + + This strategy shards a model's weights according to a provided configuration + without altering the model's architecture. It identifies weights + that match specific patterns, applies sharding actions to them, and stores + the mapping between original and sharded weights. + """ + + def __init__(self, world_size, rank): + """Initializes the ParameterShardingStrategy. + + Args: + world_size (int): The total number of devices in distributed setup. + rank (int): The rank of the current device. + """ + self.world_size = world_size + self.rank = rank + self.sharded_weights = {} + self.original_weights = {} + self.weight_mapping = {} + self.sharded_weights_by_id = {} + + def shard_model_parameters(self, model, config, device_id): + """Shards model parameters based on a layout configuration. + + This method iterates through the rules in configuration, finds matching + parameters in the model, and applies the specified sharding action. It + then returns a `ParameterShardedModel` wrapper that uses these sharded + weights. + + Args: + model (keras.Model): The Keras model to be sharded. + config (LayoutMap): A configuration object specifying which weights + to shard and how. + device_id: The device identifier for the current process. + + Returns: + tuple: A tuple containing: + - ParameterShardedModel: The wrapped model with sharded + parameters. + - set: A set of names of the parameters that were modified. 
+ """ + ParameterShardedModel = _define_parameter_sharded_model() + + self._store_original_weights(model) + modified_parameters = set() + + for pattern, action in config.state_rules.items(): + if hasattr(action, "__call__"): + matching_params = self._find_matching_parameters(model, pattern) + + for param_name, param in matching_params: + if hasattr(param, "experimental_ref"): + param_id = id(param.experimental_ref()) + else: + param_id = id(param) + + if param_id in self.sharded_weights_by_id: + self.sharded_weights[param_name] = ( + self.sharded_weights_by_id[param_id] + ) + + existing_param_name = "unknown" + for name, shard in self.sharded_weights.items(): + if shard is self.sharded_weights_by_id[param_id]: + existing_param_name = name + break + + self.weight_mapping[param_name] = self.weight_mapping[ + existing_param_name + ] + modified_parameters.add(param_name) + continue + + sharded_param = action(param, self.rank) + + self.sharded_weights[param_name] = sharded_param + self.sharded_weights_by_id[param_id] = sharded_param + + self.weight_mapping[param_name] = { + "original_shape": param.shape, + "sharded_shape": sharded_param.shape, + "action": action, + } + + modified_parameters.add(param_name) + + sharded_model = ParameterShardedModel( + original_model=model, + sharding_strategy=self, + config=config, + device_id=device_id, + ) + + return sharded_model, modified_parameters + + def _store_original_weights(self, model): + """Recursively finds and stores the original weights of a model.""" + from keras.src import layers + + def find_weights_recursive(current_layer, prefix=""): + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if hasattr(current_layer, "weights") and current_layer.weights: + for weight in current_layer.weights: + cleaned_name = weight.name.split("/")[-1].split(":")[0] + param_name = f"{full_name}.{cleaned_name}" + self.original_weights[param_name] = weight.numpy() + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + find_weights_recursive(sub_layer, full_name) + + for attr_name in dir(current_layer): + if attr_name.startswith("__") and attr_name.endswith("__"): + continue + + attr = getattr(current_layer, attr_name) + + if isinstance(attr, layers.Layer) and attr is not current_layer: + find_weights_recursive(attr, full_name) + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + find_weights_recursive(item, full_name) + + for layer in model.layers: + find_weights_recursive(layer, prefix="") + + def _find_matching_parameters(self, model, pattern): + """Finds model parameters that match a given regex pattern. + + Args: + model (keras.Model): The model to search within. + pattern (str): The regex pattern to match against parameter names. + + Returns: + list: A list of tuples, where each tuple contains the full parameter + name and the corresponding weight object. 
+ """ + from keras.src import layers + + matching_params = [] + processed_layers = set() + + def search_layer_recursive(current_layer, prefix=""): + if id(current_layer) in processed_layers: + return + processed_layers.add(id(current_layer)) + + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if hasattr(current_layer, "weights") and current_layer.weights: + for weight in current_layer.weights: + cleaned_weight_name = weight.name.split("/")[-1].split(":")[ + 0 + ] + param_name = f"{full_name}.{cleaned_weight_name}" + + if re.match(pattern, param_name): + matching_params.append((param_name, weight)) + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + search_layer_recursive(sub_layer, full_name) + + for attr_name in dir(current_layer): + if attr_name.startswith("__") and attr_name.endswith("__"): + continue + + attr = getattr(current_layer, attr_name) + + if isinstance(attr, layers.Layer) and attr is not current_layer: + search_layer_recursive(attr, full_name) + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + search_layer_recursive(item, full_name) + + search_layer_recursive(model, prefix="") + return matching_params + + +def _define_parameter_sharded_model(): + """Factory function to define and return the ParameterShardedModel class. + + This approach avoids circular import dependencies by defining the class + that inherits from `keras.src.models.Model` inside a function. + + Returns: + The ParameterShardedModel class definition. + """ + from keras.src.models import Model + + class ParameterShardedModel(Model): + """A wrapper model that manages sharded parameters for tensor + parallelism. + + This model wraps an existing Keras model, preserving its original + architecture. It overrides the `weights` property and the `call` method + to handle sharded weights and insert the necessary communication + collectives (e.g., AllReduce, AllGather) during the forward pass. + """ + + def __init__( + self, original_model, sharding_strategy, config, device_id + ): + """Initializes the ParameterShardedModel. + + Args: + original_model: The original, unsharded Keras model. + sharding_strategy: The strategy object + that contains the sharded weights and mappings. + config (LayoutMap): The sharding configuration. + device_id: The device identifier for the current process. + """ + super().__init__() + self.original_model = original_model + self.sharding_strategy = sharding_strategy + self.config = config + self._device = device_id + self._build_and_cache_weights() + if original_model.inputs: + self.build(original_model.inputs[0].shape) + + @property + def device(self): + """Returns the device ID associated with this model shard.""" + return self._device + + def _build_and_cache_weights(self): + """Constructs and caches the definitive list of model weights. + + This method combines newly created `ShardedWeight` objects with any + original weights that were not sharded (i.e., replicated weights). + This combined list is then cached to be returned by the `weights` + property, ensuring the optimizer sees all trainable parameters. 
+ """ + weights_list = [] + sharded_weight_ids = set( + self.sharding_strategy.sharded_weights_by_id.keys() + ) + + for ( + param_name, + sharded_tensor, + ) in self.sharding_strategy.sharded_weights.items(): + weights_list.append(ShardedWeight(sharded_tensor, param_name)) + + for weight in self.original_model.weights: + if hasattr(weight, "experimental_ref"): + weight_id = id(weight.experimental_ref()) + else: + weight_id = id(weight) + + if weight_id not in sharded_weight_ids: + weights_list.append(weight) + + self._weights_list = weights_list + + @property + def weights(self): + """Overrides the base property to return the cached list of weights. + + This list includes both the custom `ShardedWeight` objects and any + unsharded (replicated) weights from the original model. + """ + return self._weights_list + + def call(self, inputs, training=None, mask=None): + """Executes the forward pass of the model with sharded parameters. + + This method manually reconstructs the forward pass of original + model's computation graph. It propagates tensors from one layer to + the next, and after layer's computation, it checks if communication + collective needs to be applied to the output tensor based on the + sharding configuration. + + Args: + inputs: Input tensor(s). + training (bool, optional): Indicates whether the model is in + training mode. Defaults to None. + mask: Mask tensor(s). Defaults to None. + + Returns: + The final output tensor(s) of the model. + """ + from keras.src import layers + + tensor_cache = {} + + if isinstance(inputs, dict): + for inp_tensor in self.original_model.inputs: + tensor_cache[id(inp_tensor)] = inputs[inp_tensor.name] + else: + tensor_cache[id(self.original_model.inputs[0])] = inputs + + for layer in self.original_model.layers: + if isinstance(layer, layers.InputLayer): + continue + + layer_inputs = [] + for node in layer._inbound_nodes: + for symbolic_input_tensor in node.input_tensors: + layer_inputs.append( + tensor_cache[id(symbolic_input_tensor)] + ) + + if len(layer_inputs) == 1: + layer_inputs = layer_inputs[0] + + current_tensor = layer(layer_inputs, training=training) + tensor_cache[id(layer.output)] = current_tensor + + layer_path = layer.path + output_rule = None + for pattern, rule in self.config.output_rules.items(): + if re.search(pattern, layer_path): + output_rule = rule.get(0) + break + if output_rule: + current_tensor = self._apply_communication( + current_tensor, layer.name, output_rule + ) + tensor_cache[id(layer.output)] = current_tensor + + final_outputs = [] + for symbolic_output in self.original_model.outputs: + final_outputs.append(tensor_cache[id(symbolic_output)]) + + if len(final_outputs) == 1: + return final_outputs[0] + return final_outputs + + def _apply_communication(self, sharded_output, layer_name, rule_str): + """Applies a collective communication operation to a tensor. + + This method uses the distributed backend to perform operations like + AllReduce (for summing partial results in row-parallel layouts) or + AllGather (for combining results in column-parallel layouts). + + Args: + sharded_output: The tensor to apply the communication op to. + layer_name (str): The name of the layer producing the output. + rule_str (str): A string from config describing the operation + (e.g., 'allreduce sum', 'allgather -1'). + + Returns: + The tensor after the communication operation has been applied. 
+ """ + comm_ops = distributed_backend.get_communication_ops() + + if "sum" in rule_str or "allreduce" in rule_str: + return comm_ops["all_reduce"]( + sharded_output, op="sum", axis_name="model" + ) + elif "gather" in rule_str: + parts = rule_str.split(" ") + last_part = parts[-1] + if len(parts) > 1 and ( + last_part.isdigit() + or (last_part.startswith("-") and last_part[1:].isdigit()) + ): + dim = int(last_part) + else: + dim = -1 + return comm_ops["all_gather"]( + sharded_output, axis=dim, axis_name="model" + ) + else: + return sharded_output + + def get_config(self): + """Returns the configuration of the original model.""" + return self.original_model.get_config() + + @classmethod + def from_config(cls, config, custom_objects=None): + """Creates a model from its configuration.""" + return cls(**config) + + return ParameterShardedModel + + +def make_parameter_sharded_model(module, config, rank, world_size, device_id): + """Creates a parameter-sharded version of a Keras model. + + This is the main entry point for applying parameter sharding. It initializes + the sharding strategy and uses it to transform the given model. + + Args: + module (keras.Model): The Keras model to shard. + config (LayoutMap): The configuration defining the sharding rules. + rank (int): The rank of the current device. + world_size (int): The total number of devices. + device_id: The identifier for the current device. + + Returns: + tuple: A tuple containing: + - ParameterShardedModel: The new, sharded model wrapper. + - set: A set of names of the parameters that were sharded. + """ + sharding_strategy = ParameterShardingStrategy(world_size, rank) + sharded_model, modified_parameters = ( + sharding_strategy.shard_model_parameters(module, config, device_id) + ) + return sharded_model, modified_parameters diff --git a/keras/src/distribution/tensor_parallel/parameter_sharding_test.py b/keras/src/distribution/tensor_parallel/parameter_sharding_test.py new file mode 100644 index 000000000000..c39507c77365 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/parameter_sharding_test.py @@ -0,0 +1,141 @@ +import os + +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" +import re + +import numpy as np +import pytest + +import keras +from keras import distribution +from keras.src import backend +from keras.src.distribution.tensor_parallel.parameter_sharding import ( + ShardedWeight, +) +from keras.src.distribution.tensor_parallel.parameter_sharding import ( + make_parameter_sharded_model, +) +from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap +from keras.src.distribution.tensor_parallel.tensor_layout import Split +from keras.src.testing import TestCase + + +def _create_simple_mlp(): + inputs = keras.Input(shape=(16,), name="input") + x = keras.layers.Dense(32, use_bias=True, name="up_proj")(inputs) + x = keras.layers.Activation("relu")(x) + outputs = keras.layers.Dense(8, use_bias=False, name="down_proj")(x) + return keras.Model(inputs=inputs, outputs=outputs, name="simple_mlp") + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="This test is for the JAX backend only.", +) +class ParameterShardingTest(TestCase): + def setUp(self): + super().setUp() + import logging + + logging.getLogger().setLevel(logging.ERROR) + + self.world_size = 2 + all_devices = distribution.list_devices() + self.devices = all_devices[: self.world_size] + if len(self.devices) < self.world_size: + self.skipTest( + f"""Not enough devices to run TP test. 
+ Found {len(self.devices)}, need {self.world_size}""" + ) + + self.original_model = _create_simple_mlp() + self.original_model.build(input_shape=(None, 16)) + + self.tp_config = LayoutMap( + state_rules={ + re.escape("simple_mlp.up_proj.kernel"): Split( + self.world_size, dim=1 + ), + re.escape("simple_mlp.down_proj.kernel"): Split( + self.world_size, dim=0 + ), + }, + output_rules={}, + ) + self.input_data = np.random.rand(4, 16).astype("float32") + self.labels = np.random.rand(4, 8).astype("float32") + + def test_model_sharding_creation_and_weight_counts(self): + sharded_models = [] + for rank in range(self.world_size): + with keras.device(self.devices[rank]): + sharded_model, modified_params = make_parameter_sharded_model( + self.original_model, + self.tp_config, + rank=rank, + world_size=self.world_size, + device_id=self.devices[rank], + ) + self.assertIsInstance(sharded_model, keras.Model) + self.assertIn("simple_mlp.up_proj.kernel", modified_params) + self.assertIn("simple_mlp.down_proj.kernel", modified_params) + sharded_models.append(sharded_model) + self.assertEqual( + len(self.original_model.weights), len(sharded_models[0].weights) + ) + + def test_sharded_weight_shapes(self): + rank = 0 + with keras.device(self.devices[rank]): + sharded_model, _ = make_parameter_sharded_model( + self.original_model, + self.tp_config, + rank=rank, + world_size=self.world_size, + device_id=self.devices[rank], + ) + original_weights_dict = {w.path: w for w in self.original_model.weights} + sharded_weights_dict = { + w.name if isinstance(w, ShardedWeight) else w.path: w + for w in sharded_model.weights + } + orig_up_kernel = original_weights_dict["up_proj/kernel"] + shard_up_kernel = sharded_weights_dict["simple_mlp.up_proj.kernel"] + self.assertEqual(shard_up_kernel.shape[0], orig_up_kernel.shape[0]) + self.assertEqual( + shard_up_kernel.shape[1], + orig_up_kernel.shape[1] // self.world_size, + ) + orig_down_kernel = original_weights_dict["down_proj/kernel"] + shard_down_kernel = sharded_weights_dict["simple_mlp.down_proj.kernel"] + self.assertEqual( + shard_down_kernel.shape[0], + orig_down_kernel.shape[0] // self.world_size, + ) + self.assertEqual(shard_down_kernel.shape[1], orig_down_kernel.shape[1]) + + def test_forward_pass_correctness(self): + expected_output = self.original_model(self.input_data) + sharded_outputs = [] + original_weights = self.original_model.get_weights() + for rank in range(self.world_size): + with keras.device(self.devices[rank]): + cloned_original = keras.models.clone_model(self.original_model) + cloned_original.set_weights(original_weights) + sharded_model, _ = make_parameter_sharded_model( + cloned_original, + self.tp_config, + rank=rank, + world_size=self.world_size, + device_id=self.devices[rank], + ) + output = sharded_model(self.input_data) + sharded_outputs.append(output) + reconstructed_output = ( + keras.ops.sum(keras.ops.stack(sharded_outputs), axis=0) + / self.world_size + ) + + self.assertAllClose( + expected_output, reconstructed_output, atol=1e-5, rtol=1e-5 + ) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 11e4046c7b8a..ba4abbe1139a 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -39,7 +39,8 @@ from keras.src.backend.common.remat import get_current_remat_mode from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.backend.config import is_nnx_enabled -from keras.src.distribution import distribution_lib + +# from keras.src.distribution import distribution_lib from 
keras.src.dtype_policies import DTypePolicyMap
 from keras.src.layers import input_spec
 from keras.src.metrics.metric import Metric
@@ -942,6 +943,8 @@ def maybe_convert(x):
         # Change the layout for the layer output if needed.
         # This is useful for relayout intermediate tensor in the model
         # to achieve the optimal performance.
+        from keras.src.distribution import distribution_lib
+
         distribution = distribution_lib.distribution()
         if distribution is not None:
             current_layer_path = current_path()
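
Usage sketch (reviewer addition, not part of the diff): how the new torch `distributed_backend` helpers in keras/src/backend/torch/distributed_backend.py are expected to be driven. The module path and function names come from the diff above; the `torchrun` launch command and the two-process world size are assumptions, and `setup_distributed_environment()` hard-codes the NCCL backend, so a real multi-process run would need CUDA devices. Without an initialized process group the collectives simply return their input.

import os

os.environ["KERAS_BACKEND"] = "torch"

import torch

from keras.src.backend.torch import distributed_backend

# Device discovery works in any configuration.
print(distributed_backend.get_device_info())
print(distributed_backend.is_multi_device_capable())

# Only succeeds when MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set,
# e.g. by `torchrun --nproc_per_node=2 this_script.py` (assumed launcher).
initialized = distributed_backend.setup_distributed_environment()

ops = distributed_backend.get_communication_ops()
x = torch.ones((2, 2))

# Real collectives when the process group is initialized; identity otherwise.
summed = ops["all_reduce"](x, op="sum")
gathered = ops["all_gather"](x, axis=0)
print(initialized, summed.shape, gathered.shape)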
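
A second sketch, also a reviewer addition, exercising the query helpers added in keras/src/backend/torch/distribution_lib.py. Only the single-process paths are shown; `distribute_tensor`, `distribute_data_input`, and the `_to_backend_*` converters additionally require an initialized `torch.distributed` process group and a DTensor-capable PyTorch build.

import os

os.environ["KERAS_BACKEND"] = "torch"

from keras.src.backend.torch import distribution_lib

# Device discovery: CUDA devices when available, otherwise ["cpu:0"].
print(distribution_lib.list_devices())
print(distribution_lib.list_devices("cpu"))           # ["cpu:0"]
print(distribution_lib.get_device_info("cpu:0"))      # {"id": "cpu:0", "type": "CPU", ...}
print(distribution_lib.get_best_devices(2))           # truncated to what actually exists
print(distribution_lib.validate_device_placement("cpu:0"))  # True

# Without torch.distributed initialized these fall back to a single process.
print(distribution_lib.num_processes(), distribution_lib.process_id())  # 1 0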
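
Finally, a sketch of the tensor-parallel entry point `make_parameter_sharded_model`, condensed from parameter_sharding_test.py above. It assumes the JAX backend with two host CPU devices (as the test does); the `LayoutMap` and `Split` signatures are taken from the test's imports and may evolve with the rest of the API.

import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import re

import keras
from keras import distribution
from keras.src.distribution.tensor_parallel.parameter_sharding import (
    make_parameter_sharded_model,
)
from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
from keras.src.distribution.tensor_parallel.tensor_layout import Split

inputs = keras.Input(shape=(16,))
x = keras.layers.Dense(32, name="up_proj")(inputs)
outputs = keras.layers.Dense(8, use_bias=False, name="down_proj")(x)
model = keras.Model(inputs, outputs, name="simple_mlp")

world_size = 2
config = LayoutMap(
    state_rules={
        # Column-parallel first projection, row-parallel second projection.
        re.escape("simple_mlp.up_proj.kernel"): Split(world_size, dim=1),
        re.escape("simple_mlp.down_proj.kernel"): Split(world_size, dim=0),
    },
    output_rules={},
)

devices = distribution.list_devices()[:world_size]
shards = []
for rank in range(world_size):
    with keras.device(devices[rank]):
        sharded, modified = make_parameter_sharded_model(
            model,
            config,
            rank=rank,
            world_size=world_size,
            device_id=devices[rank],
        )
        shards.append(sharded)

# Each shard keeps the same number of weights as the original model, with the
# matched kernels replaced by per-rank slices.
print(modified)
print(len(model.weights), len(shards[0].weights))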