Merged
Commits
70 commits
a27367a
Added tensor parallel for keras (Part 1/3)
buildwithsuhana Sep 26, 2025
488cd8f
Removed unnecessary lines
buildwithsuhana Sep 26, 2025
71ddd1a
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
bc4e4e2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
d4200b5
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
21f89a2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
299bd45
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
da625e1
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
c233b8c
Fixing the failing test
buildwithsuhana Sep 26, 2025
7b8d733
Fixing the failing test
buildwithsuhana Sep 26, 2025
f825cd3
Fixing test
buildwithsuhana Sep 26, 2025
3725180
Adding tests for distributed_backends
buildwithsuhana Sep 29, 2025
a6c8a96
Modifications for failing tests
buildwithsuhana Sep 29, 2025
3fabfde
Modified for failing test
buildwithsuhana Sep 29, 2025
b133752
Modified for failing test
buildwithsuhana Sep 29, 2025
83c2e3f
Modified for failing test
buildwithsuhana Sep 29, 2025
3f3be6b
added debuggers
buildwithsuhana Sep 29, 2025
be325ab
removed debuggers
buildwithsuhana Sep 29, 2025
e1282ac
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Sep 29, 2025
fc11aaa
Removed the tensorflow, numpy and torch backends
buildwithsuhana Sep 30, 2025
ef6e2a0
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Sep 30, 2025
bea6ffa
Refactoring the code
buildwithsuhana Sep 30, 2025
4e00245
Refactoring the code
buildwithsuhana Sep 30, 2025
2f973b0
refactoring
buildwithsuhana Sep 30, 2025
bdb2b84
Adding necessary docstrings
buildwithsuhana Sep 30, 2025
d77fa71
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Oct 1, 2025
b9990b0
Removing redundancies
buildwithsuhana Oct 3, 2025
0aeee6f
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Oct 3, 2025
f784956
Modifying tests
buildwithsuhana Oct 3, 2025
8895a78
Reformatting
buildwithsuhana Oct 3, 2025
fe97f3b
Reformatting the code
buildwithsuhana Oct 3, 2025
77f01aa
Fixing failing tests
buildwithsuhana Oct 3, 2025
7080328
fixes
buildwithsuhana Oct 3, 2025
af711fd
Fixing tests
buildwithsuhana Oct 3, 2025
97dde17
formatting
buildwithsuhana Oct 3, 2025
f322a97
fixing test
buildwithsuhana Oct 3, 2025
5269ac9
fixing test
buildwithsuhana Oct 3, 2025
b9f36e9
Removing redundant lines
buildwithsuhana Oct 6, 2025
555e5c9
Refactoring to remove communications.py and state_action_keras.py
buildwithsuhana Oct 12, 2025
b80d264
formatting the files
buildwithsuhana Oct 12, 2025
93b1738
fixing skip issues
buildwithsuhana Oct 12, 2025
b7b2b9b
fixing test
buildwithsuhana Oct 12, 2025
f6c1142
fixing test
buildwithsuhana Oct 12, 2025
669c799
refactoring to remove distributed backend wrapper
buildwithsuhana Oct 13, 2025
cd20b9f
fixing test
buildwithsuhana Oct 13, 2025
cd0049f
making distributed backend more JAX friendly
buildwithsuhana Oct 13, 2025
d1e4c69
Fixing comments
buildwithsuhana Oct 17, 2025
86e0557
Fixing comments
buildwithsuhana Oct 17, 2025
6c3883f
Fixing comments
buildwithsuhana Oct 17, 2025
3e31e1e
fixes
buildwithsuhana Oct 17, 2025
c99601e
Refactor
buildwithsuhana Oct 18, 2025
dbae56d
refactoring to resolve comments
buildwithsuhana Oct 18, 2025
2fc0f0e
fixes
buildwithsuhana Oct 18, 2025
174093c
fixes
buildwithsuhana Oct 18, 2025
7d18b0a
fix
buildwithsuhana Oct 18, 2025
f570925
fix
buildwithsuhana Oct 18, 2025
9e7f873
removing get_best_devices
buildwithsuhana Oct 21, 2025
5136091
fixing comments
buildwithsuhana Oct 26, 2025
8f40c53
Merge branch 'master' into Tensor_parallel_keras
buildwithsuhana Oct 26, 2025
08b8abe
fixing merge conflict
buildwithsuhana Oct 26, 2025
3a408da
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Oct 26, 2025
eb796ea
modifying variable name
buildwithsuhana Oct 26, 2025
15e1709
Fixes
buildwithsuhana Oct 28, 2025
911b96e
fix
buildwithsuhana Oct 28, 2025
bd2f19f
fix
buildwithsuhana Oct 28, 2025
71d079f
splitting into 3 PRs
buildwithsuhana Oct 28, 2025
7789084
Modified array_split implementation in openvino, tensorflow and torch
buildwithsuhana Oct 29, 2025
162e6c3
formatting the array split function
buildwithsuhana Oct 29, 2025
d47e3e6
adding test for uneven array split
buildwithsuhana Oct 30, 2025
f4f723d
fixing test
buildwithsuhana Oct 30, 2025
4 changes: 4 additions & 0 deletions keras/src/backend/distributed/__init__.py
@@ -0,0 +1,4 @@
from .base import BaseDistributedBackend
from .factory import get_distributed_backend

__all__ = ["get_distributed_backend", "BaseDistributedBackend"]
57 changes: 57 additions & 0 deletions keras/src/backend/distributed/base.py
@@ -0,0 +1,57 @@
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import List


class BaseDistributedBackend(ABC):
"""
Abstract Base Class for a distributed backend.
"""

@abstractmethod
def get_tensor_lib(self):
"""Get the appropriate tensor library for the backend."""
raise NotImplementedError

@abstractmethod
def convert_to_backend_tensor(self, tensor: Any) -> Any:
"""Convert a tensor to the appropriate backend format."""
raise NotImplementedError

@abstractmethod
def compute_gradients(
self, loss: Any, trainable_vars: List[Any]
) -> List[Any]:
"""Compute gradients using the backend's automatic differentiation."""
raise NotImplementedError

@abstractmethod
def apply_gradients(
self,
gradients: List[Any],
trainable_vars: List[Any],
learning_rate: float = 0.001,
) -> None:
"""Apply gradients to trainable variables."""
raise NotImplementedError

@abstractmethod
def create_optimizer(self, optimizer_class: str, **kwargs):
"""Create an optimizer for the backend."""
raise NotImplementedError

@abstractmethod
def get_device_info(self) -> dict:
"""Get information about available devices."""
raise NotImplementedError

@abstractmethod
def is_multi_device_capable(self) -> bool:
"""Check if the backend supports multi-device operations."""
raise NotImplementedError

@abstractmethod
def get_communication_ops(self) -> dict:
"""Get collective communication operations for the backend."""
raise NotImplementedError
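
For reference, a minimal sketch of how the contract above is enforced (this snippet is not part of the PR; it assumes the distributed package __init__.py shown earlier is in place): a concrete backend must implement every abstract method, and a partial subclass cannot be instantiated.

from keras.src.backend.distributed import BaseDistributedBackend


class IncompleteBackend(BaseDistributedBackend):
    # Implements only one of the eight abstract methods on purpose.
    def get_tensor_lib(self):
        return None


try:
    IncompleteBackend()
except TypeError as err:
    # abc refuses to instantiate a class with unimplemented abstract methods
    print(err)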
50 changes: 50 additions & 0 deletions keras/src/backend/distributed/factory.py
@@ -0,0 +1,50 @@
import logging

from keras.src.backend.distributed.base import BaseDistributedBackend

from keras.src.backend.jax.distributed_backend import JaxDistributedBackend
from keras.src.backend.numpy.distributed_backend import NumpyDistributedBackend
from keras.src.backend.tensorflow.distributed_backend import (
TensorflowDistributedBackend,
)
from keras.src.backend.torch.distributed_backend import (
PytorchDistributedBackend,
)

logger = logging.getLogger(__name__)


def get_distributed_backend(
backend_name: str = "auto",
) -> BaseDistributedBackend:
"""
Factory to get the best available or a specific distributed backend.
"""
if backend_name == "auto":
try:
logger.info("Auto-detected JAX for distributed backend.")
return JaxDistributedBackend()
except ImportError:
try:
logger.info("Auto-detected TensorFlow for distributed backend.")
return TensorflowDistributedBackend()
except ImportError:
try:
logger.info(
"Auto-detected PyTorch for distributed backend."
)
return PytorchDistributedBackend()
except ImportError:
logger.warning("Using NumPy distributed backend.")
return NumpyDistributedBackend()

elif backend_name == "jax":
return JaxDistributedBackend()
elif backend_name == "tensorflow":
return TensorflowDistributedBackend()
elif backend_name == "pytorch":
return PytorchDistributedBackend()
elif backend_name == "numpy":
return NumpyDistributedBackend()
else:
raise ValueError(f"Unknown distributed backend: {backend_name}")
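
As a usage illustration only (not part of the diff), the factory can be exercised as follows; the import path assumes the distributed package __init__.py shown above:

from keras.src.backend.distributed import get_distributed_backend

# "auto" falls back through JAX -> TensorFlow -> PyTorch -> NumPy, as above.
backend = get_distributed_backend("auto")
print(backend.get_device_info())  # e.g. {"backend": "jax", "devices": [...], "device_count": 1}

# A backend can also be requested by name; unknown names raise ValueError.
numpy_backend = get_distributed_backend("numpy")
ops = numpy_backend.get_communication_ops()  # dict of collective-op callables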
141 changes: 141 additions & 0 deletions keras/src/backend/jax/distributed_backend.py
@@ -0,0 +1,141 @@
import logging
from typing import Any
from typing import List

import jax
import jax.lax as lax
import jax.numpy as jnp
import optax

from keras.src.backend.distributed.base import BaseDistributedBackend

logger = logging.getLogger(__name__)


class JaxDistributedBackend(BaseDistributedBackend):
"""JAX-specific implementation of distributed operations."""

def get_tensor_lib(self):
return jnp

def convert_to_backend_tensor(self, tensor: Any) -> Any:
if hasattr(tensor, "numpy"):
return jnp.array(tensor.numpy())
else:
return jnp.array(tensor)

def compute_gradients(
self, loss: Any, trainable_vars: List[Any]
) -> List[Any]:
def safe_convert_to_jax(tensor):
try:
if hasattr(tensor, "numpy"):
if hasattr(tensor, "shape") and tensor.shape is None:
logger.warning("Symbolic tensor detected")
return jnp.array(0.0)
else:
return jnp.array(tensor.numpy())
else:
return jnp.array(tensor)
except Exception as e:
logger.warning(
f"Failed to convert tensor to JAX: {e}, using dummy value"
)
return jnp.array(0.0)

loss_jax = safe_convert_to_jax(loss)
params_jax = [safe_convert_to_jax(param) for param in trainable_vars]

def loss_fn(params):
return loss_jax

try:
gradients = jax.grad(loss_fn)(params_jax)
logger.info(" - JAX gradient computation successful")
return gradients
except Exception as e:
logger.warning(
f"JAX gradient computation failed: {e}, using fallback"
)
return [jnp.zeros_like(param) for param in params_jax]

def apply_gradients(
self,
gradients: List[Any],
trainable_vars: List[Any],
learning_rate: float = 0.001,
) -> None:
for grad, var in zip(gradients, trainable_vars):
if grad is not None:
new_value = var - (learning_rate * grad)
if hasattr(var, "assign"):
var.assign(new_value)

def create_optimizer(self, optimizer_class: str, **kwargs):
if optimizer_class.lower() == "adam":
return optax.adam(**kwargs)
elif optimizer_class.lower() == "sgd":
return optax.sgd(**kwargs)
else:
return optax.adam(learning_rate=0.001)

def get_device_info(self) -> dict:
info = {"backend": "jax", "devices": [], "device_count": 0}
try:
info["devices"] = [str(d) for d in jax.devices()]
info["device_count"] = jax.local_device_count()
except Exception as e:
logger.warning(f"Could not get device info for JAX: {e}")
info["devices"] = ["cpu"]
info["device_count"] = 1
return info

def is_multi_device_capable(self) -> bool:
return self.get_device_info()["device_count"] > 1

def get_communication_ops(self) -> dict:
def all_reduce_jax(x, op="sum", axis_name="data"):
return lax.pmean(x, axis_name=axis_name)

def all_gather_jax(x, axis=0, axis_name="model"):
return lax.all_gather(x, axis_name=axis_name, axis=axis)

def broadcast_jax(x, axis_name="data"):
return lax.all_gather(x, axis_name=axis_name, axis=0)

def scatter_jax(x, num_devices, axis_name="data"):
return lax.psplit(x, axis_name=axis_name, num_splits=num_devices)

def all_reduce_simulated(x, op="sum", axis_name="data"):
return jnp.sum(x, axis=0)

def all_gather_simulated(x, axis=0, axis_name="model"):
return jnp.concatenate([x, x], axis=axis)

def broadcast_simulated(x):
return x

def scatter_simulated(x, num_devices):
return jnp.split(x, num_devices, axis=0)

try:
if jax.device_count() > 1:
logger.info("Using real JAX collective communication ops.")
return {
"all_reduce": all_reduce_jax,
"all_gather": all_gather_jax,
"broadcast": broadcast_jax,
"scatter": scatter_jax,
}
else:
raise RuntimeError("Not running on multiple JAX devices.")
except (ImportError, RuntimeError) as e:
logger.warning(
f"JAX collective ops not available: {e}. Using SIMULATED ops."
)
return {
"all_reduce": all_reduce_simulated,
"all_gather": all_gather_simulated,
"broadcast": broadcast_simulated,
"scatter": scatter_simulated,
}
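
A hypothetical sketch of driving the returned collectives, assuming more than one JAX device so the real (non-simulated) ops are selected; the axis_name given to pmap has to match the default "data" axis used by the ops above:

import functools

import jax
import jax.numpy as jnp

from keras.src.backend.jax.distributed_backend import JaxDistributedBackend

backend = JaxDistributedBackend()
ops = backend.get_communication_ops()  # real collectives only if jax.device_count() > 1


@functools.partial(jax.pmap, axis_name="data")
def mean_across_devices(x):
    # all_reduce maps to lax.pmean over the "data" axis inside pmap
    return ops["all_reduce"](x)


n = jax.local_device_count()
sharded = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
print(mean_across_devices(sharded))  # every device row holds the same per-column mean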
105 changes: 105 additions & 0 deletions keras/src/backend/numpy/distributed_backend.py
@@ -0,0 +1,105 @@
import logging
from typing import Any
from typing import List

import numpy as np

import keras
from keras.src.backend.distributed.base import BaseDistributedBackend

logger = logging.getLogger(__name__)


class NumpyDistributedBackend(BaseDistributedBackend):
"""NumPy-based fallback implementation of distributed operations."""

def get_tensor_lib(self):
return np

def convert_to_backend_tensor(self, tensor: Any) -> Any:
return keras.ops.convert_to_numpy(tensor)

def compute_gradients(
self, loss: Any, trainable_vars: List[Any]
) -> List[Any]:
epsilon = 1e-7
gradients = []
for var in trainable_vars:
if hasattr(var, "shape"):
grad = np.zeros_like(var)
it = np.nditer(
var, flags=["multi_index"], op_flags=["readwrite"]
)
while not it.finished:
idx = it.multi_index
original_value = var[idx]
var[idx] = original_value + epsilon
# This part is flawed as loss is a scalar.
# Numerical differentiation needs a function to re-evaluate.
# This is a placeholder for a no-op.
loss_plus = loss
var[idx] = original_value - epsilon
loss_minus = loss
grad[idx] = (loss_plus - loss_minus) / (
2 * epsilon
) # Will be 0
var[idx] = original_value # Restore
it.iternext()
gradients.append(grad)
else:
gradients.append(0.0)
return gradients

def apply_gradients(
self,
gradients: List[Any],
trainable_vars: List[Any],
learning_rate: float = 0.001,
) -> None:
for grad, var in zip(gradients, trainable_vars):
if grad is not None:
new_value = var - (learning_rate * grad)
if hasattr(var, "assign"):
var.assign(new_value)
else:
var[:] = new_value

def create_optimizer(self, optimizer_class: str, **kwargs):
class NumpyOptimizer:
def __init__(self, learning_rate=0.001):
self.learning_rate = learning_rate

def apply_gradients(self, grads_and_vars):
for grad, var in grads_and_vars:
if grad is not None:
var -= self.learning_rate * grad

return NumpyOptimizer(**kwargs)

def get_device_info(self) -> dict:
return {"backend": "numpy", "devices": ["cpu"], "device_count": 1}

def is_multi_device_capable(self) -> bool:
return False

def get_communication_ops(self) -> dict:
logger.info("Using SIMULATED NumPy communication ops.")

def all_reduce_np(x, op="sum"):
return keras.ops.sum(x, axis=0)

def all_gather_np(x, axis=0):
return keras.ops.concatenate([x, x], axis=axis)

def broadcast_np(x):
return x

def scatter_np(x, num_devices):
return keras.ops.split(x, num_devices, axis=0)

return {
"all_reduce": all_reduce_np,
"all_gather": all_gather_np,
"broadcast": broadcast_np,
"scatter": scatter_np,
}
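
As a quick illustration (not part of the diff), the simulated NumPy collectives treat the leading axis of a single array as the "device" axis; since they are built on keras.ops, the exact tensor type returned depends on the active Keras backend.

import numpy as np

from keras.src.backend.numpy.distributed_backend import NumpyDistributedBackend

backend = NumpyDistributedBackend()
ops = backend.get_communication_ops()

x = np.arange(8.0).reshape(2, 4)    # pretend axis 0 indexes two "devices"
print(ops["all_reduce"](x))         # sums over axis 0 -> shape (4,)
print(ops["all_gather"](x).shape)   # concatenates x with itself -> (4, 4)
print(len(ops["scatter"](x, 2)))    # splits into 2 chunks along axis 0 -> 2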