Skip to content
Merged
Show file tree
Hide file tree
Changes from 68 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
a27367a
Added tensor parallel for keras (Part 1/3)
buildwithsuhana Sep 26, 2025
488cd8f
Removed unnecessary lines
buildwithsuhana Sep 26, 2025
71ddd1a
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
bc4e4e2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
d4200b5
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
21f89a2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
299bd45
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
da625e1
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
c233b8c
Fixing the failing test
buildwithsuhana Sep 26, 2025
7b8d733
Fixing the failing test
buildwithsuhana Sep 26, 2025
f825cd3
Fixing test
buildwithsuhana Sep 26, 2025
3725180
Adding tests for distributed_backends
buildwithsuhana Sep 29, 2025
a6c8a96
Modifications for failing tests
buildwithsuhana Sep 29, 2025
3fabfde
Modified for failing test
buildwithsuhana Sep 29, 2025
b133752
Modified for failing test
buildwithsuhana Sep 29, 2025
83c2e3f
Modified for failing test
buildwithsuhana Sep 29, 2025
3f3be6b
added debuggers
buildwithsuhana Sep 29, 2025
be325ab
removed debuggers
buildwithsuhana Sep 29, 2025
e1282ac
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Sep 29, 2025
fc11aaa
Removed the tensorflow, numpy and torch backends
buildwithsuhana Sep 30, 2025
ef6e2a0
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Sep 30, 2025
bea6ffa
Refactoring the code
buildwithsuhana Sep 30, 2025
4e00245
Refactoring the code
buildwithsuhana Sep 30, 2025
2f973b0
refactoring
buildwithsuhana Sep 30, 2025
bdb2b84
Adding necessary docstrings
buildwithsuhana Sep 30, 2025
d77fa71
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Oct 1, 2025
b9990b0
Removing redundancies
buildwithsuhana Oct 3, 2025
0aeee6f
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Oct 3, 2025
f784956
Modifying tests
buildwithsuhana Oct 3, 2025
8895a78
Reformatting
buildwithsuhana Oct 3, 2025
fe97f3b
Reformatting the code
buildwithsuhana Oct 3, 2025
77f01aa
Fixing failing tests
buildwithsuhana Oct 3, 2025
7080328
fixes
buildwithsuhana Oct 3, 2025
af711fd
Fixing tests
buildwithsuhana Oct 3, 2025
97dde17
formatting
buildwithsuhana Oct 3, 2025
f322a97
fixing test
buildwithsuhana Oct 3, 2025
5269ac9
fixing test
buildwithsuhana Oct 3, 2025
b9f36e9
Removing redundant lines
buildwithsuhana Oct 6, 2025
555e5c9
Refactoring to remove communications.py and state_action_keras.py
buildwithsuhana Oct 12, 2025
b80d264
formatting the files
buildwithsuhana Oct 12, 2025
93b1738
fixing skip issues
buildwithsuhana Oct 12, 2025
b7b2b9b
fixing test
buildwithsuhana Oct 12, 2025
f6c1142
fixing test
buildwithsuhana Oct 12, 2025
669c799
refactoring to remove distributed backend wrapper
buildwithsuhana Oct 13, 2025
cd20b9f
fixing test
buildwithsuhana Oct 13, 2025
cd0049f
making distrubed backend more jax friendly
buildwithsuhana Oct 13, 2025
d1e4c69
Fixing comments
buildwithsuhana Oct 17, 2025
86e0557
Fixing comments
buildwithsuhana Oct 17, 2025
6c3883f
Fixing comments
buildwithsuhana Oct 17, 2025
3e31e1e
fixes
buildwithsuhana Oct 17, 2025
c99601e
Refactor
buildwithsuhana Oct 18, 2025
dbae56d
refactoring to resolve comments
buildwithsuhana Oct 18, 2025
2fc0f0e
fixes
buildwithsuhana Oct 18, 2025
174093c
fixes
buildwithsuhana Oct 18, 2025
7d18b0a
fix
buildwithsuhana Oct 18, 2025
f570925
fix
buildwithsuhana Oct 18, 2025
9e7f873
removing get_best_devices
buildwithsuhana Oct 21, 2025
5136091
fixing comments
buildwithsuhana Oct 26, 2025
8f40c53
Merge branch 'master' into Tensor_parallel_keras
buildwithsuhana Oct 26, 2025
08b8abe
fixing merge conflict
buildwithsuhana Oct 26, 2025
3a408da
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Oct 26, 2025
eb796ea
modifying variable name
buildwithsuhana Oct 26, 2025
15e1709
Fixes
buildwithsuhana Oct 28, 2025
911b96e
fix
buildwithsuhana Oct 28, 2025
bd2f19f
fix
buildwithsuhana Oct 28, 2025
71d079f
splitting into 3 PRs
buildwithsuhana Oct 28, 2025
7789084
Modified array_split implementation in openvino, tensorflow and torch
buildwithsuhana Oct 29, 2025
162e6c3
formatting the array split function
buildwithsuhana Oct 29, 2025
d47e3e6
adding test for uneven array split
buildwithsuhana Oct 30, 2025
f4f723d
fixing test
buildwithsuhana Oct 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
from keras.src.ops.numpy import argpartition as argpartition
from keras.src.ops.numpy import argsort as argsort
from keras.src.ops.numpy import array as array
from keras.src.ops.numpy import array_split as array_split
from keras.src.ops.numpy import average as average
from keras.src.ops.numpy import bartlett as bartlett
from keras.src.ops.numpy import bincount as bincount
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from keras.src.ops.numpy import argpartition as argpartition
from keras.src.ops.numpy import argsort as argsort
from keras.src.ops.numpy import array as array
from keras.src.ops.numpy import array_split as array_split
from keras.src.ops.numpy import average as average
from keras.src.ops.numpy import bartlett as bartlett
from keras.src.ops.numpy import bincount as bincount
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
from keras.src.ops.numpy import argpartition as argpartition
from keras.src.ops.numpy import argsort as argsort
from keras.src.ops.numpy import array as array
from keras.src.ops.numpy import array_split as array_split
from keras.src.ops.numpy import average as average
from keras.src.ops.numpy import bartlett as bartlett
from keras.src.ops.numpy import bincount as bincount
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from keras.src.ops.numpy import argpartition as argpartition
from keras.src.ops.numpy import argsort as argsort
from keras.src.ops.numpy import array as array
from keras.src.ops.numpy import array_split as array_split
from keras.src.ops.numpy import average as average
from keras.src.ops.numpy import bartlett as bartlett
from keras.src.ops.numpy import bincount as bincount
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,11 @@ def split(x, indices_or_sections, axis=0):
return jnp.split(x, indices_or_sections, axis=axis)


def array_split(x, indices_or_sections, axis=0):
x = convert_to_tensor(x)
return jnp.array_split(x, indices_or_sections, axis=axis)


def stack(x, axis=0):
x = [convert_to_tensor(t) for t in x]
return jnp.stack(x, axis=axis)
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,11 @@ def split(x, indices_or_sections, axis=0):
return np.split(x, indices_or_sections, axis=axis)


def array_split(x, indices_or_sections, axis=0):
axis = standardize_axis_for_numpy(axis)
return np.array_split(x, indices_or_sections, axis=axis)


def stack(x, axis=0):
axis = standardize_axis_for_numpy(axis)
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
Expand Down
31 changes: 31 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,6 +2014,37 @@ def split(x, indices_or_sections, axis=0):
)


def array_split(x, indices_or_sections, axis=0):
original_shape = x.shape
x = get_ov_output(x)

num_splits_val = indices_or_sections
total_size = original_shape[axis]
if total_size is None:
raise ValueError(
f"Cannot use array_split with static Python logic on dynamic axis. "
f"Axis {axis} has unknown dimension for shape {original_shape}."
)

base_size = total_size // num_splits_val
remainder = total_size % num_splits_val

split_lengths = [base_size + 1] * remainder + [base_size] * (
num_splits_val - remainder
)
split_lengths_tensor = ov_opset.constant(
split_lengths, dtype=Type.i64
).output(0)

axis_tensor = ov_opset.constant(axis, dtype=Type.i32).output(0)
splits = ov_opset.variadic_split(x, axis_tensor, split_lengths_tensor)

result = []
for i in range(num_splits_val):
result.append(OpenVINOKerasTensor(splits.output(i)))
return result


def stack(x, axis=0):
if isinstance(x, tuple):
x = list(x)
Expand Down
11 changes: 11 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2494,6 +2494,17 @@ def split(x, indices_or_sections, axis=0):
return tf.split(x, num_or_size_splits, axis=axis)


def array_split(x, indices_or_sections, axis=0):
x = tf.convert_to_tensor(x)
num_splits = indices_or_sections
total_size = shape_op(x)[axis]
avg_size = total_size // num_splits
remainder = total_size % num_splits
sizes = [avg_size + 1] * remainder + [avg_size] * (num_splits - remainder)

return tf.split(x, sizes, axis=axis)


def stack(x, axis=0):
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
if len(dtype_set) > 1:
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,12 @@ def split(x, indices_or_sections, axis=0):
return list(out)


def array_split(x, indices_or_sections, axis=0):
x = convert_to_tensor(x)
out = torch.tensor_split(x, indices_or_sections, dim=axis)
return list(out)


def stack(x, axis=0):
x = [convert_to_tensor(elem) for elem in x]
return torch.stack(x, dim=axis)
Expand Down
101 changes: 101 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7762,3 +7762,104 @@ def histogram(x, bins=10, range=None):
f"Received: input.shape={x.shape}"
)
return backend.numpy.histogram(x, bins=bins, range=range)


class ArraySplit(Operation):
def __init__(self, indices_or_sections, axis=0, *, name=None):
super().__init__(name=name)

self.indices_or_sections = indices_or_sections
self.axis = axis

def call(self, x):
return backend.numpy.array_split(
x,
indices_or_sections=self.indices_or_sections,
axis=self.axis,
)

def compute_output_spec(self, x):
num_splits = self.indices_or_sections

axis = self.axis
if axis < 0:
axis += len(x.shape)

total_size = x.shape[axis]

if total_size is None:
output_specs = []
base_shape = list(x.shape)
base_shape[axis] = None
for _ in range(num_splits):
output_specs.append(
KerasTensor(shape=tuple(base_shape), dtype=x.dtype)
)
return tuple(output_specs)

split_size = total_size // num_splits
remainder = total_size % num_splits

output_specs = []
base_shape = list(x.shape)
for i in range(num_splits):
size = split_size + (1 if i < remainder else 0)
shape = base_shape.copy()
shape[axis] = size
output_specs.append(KerasTensor(shape=tuple(shape), dtype=x.dtype))

return list(output_specs)


@keras_export(["keras.ops.array_split", "keras.ops.numpy.array_split"])
def array_split(x, indices_or_sections, axis=0):
"""Splits an array into multiple sub-arrays (unevenly).

This is similar to `keras.ops.split`, but it allows for
unequal splits. `indices_or_sections` must be an integer
that indicates the total number of sub-arrays to create.
If the tensor cannot be divided evenly, the first `remainder`
splits will have size `quotient + 1`, and the rest will
have size `quotient`.

Args:
x: Input tensor.
indices_or_sections: An integer indicating the number of
sub-arrays to create.
axis: The axis along which to split. Defaults to 0.

Returns:
A list of sub-tensors.

Example:
>>> x = keras.ops.arange(10)
>>> keras.ops.array_split(x, 3)
(array([0, 1, 2, 3], dtype=int32),
array([4, 5, 6], dtype=int32),
array([7, 8, 9], dtype=int32))
"""
if not isinstance(indices_or_sections, int):
raise TypeError(
"Argument `indices_or_sections` must be of type `int`. "
f"Received: indices_or_sections={indices_or_sections}"
)

if indices_or_sections <= 0:
raise ValueError(
"Argument `indices_or_sections` must be a positive integer. "
f"Received: indices_or_sections={indices_or_sections}"
)

if not isinstance(axis, int):
raise TypeError(
f"Argument `axis` must be of type `int`. Received: {axis}"
)

if any_symbolic_tensors((x,)):
return ArraySplit(
indices_or_sections=indices_or_sections, axis=axis
).symbolic_call(x)

return backend.numpy.array_split(
x, indices_or_sections=indices_or_sections, axis=axis
)
35 changes: 35 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,6 +1869,13 @@ def test_view(self):
self.assertEqual(knp.view(x, dtype="int32").shape, (None, 2))
self.assertEqual(knp.view(x, dtype="int32").dtype, "int32")

def test_array_split(self):
x = KerasTensor((None, 4))
splits = knp.array_split(x, 2, axis=0)
self.assertEqual(len(splits), 2)
self.assertEqual(splits[0].shape, (None, 4))
self.assertEqual(splits[1].shape, (None, 4))


class NumpyOneInputOpsStaticShapeTest(testing.TestCase):
def test_mean(self):
Expand Down Expand Up @@ -2481,6 +2488,14 @@ def test_view(self):
self.assertEqual(knp.view(x, dtype="int32").shape, (2, 2))
self.assertEqual(knp.view(x, dtype="int32").dtype, "int32")

def test_array_split(self):
x = KerasTensor((8, 4))
splits = knp.array_split(x, 3, axis=0)
self.assertEqual(len(splits), 3)
self.assertEqual(splits[0].shape, (3, 4))
self.assertEqual(splits[1].shape, (3, 4))
self.assertEqual(splits[2].shape, (2, 4))


class NumpyTwoInputOpsCorrectnessTest(testing.TestCase):
def test_add(self):
Expand Down Expand Up @@ -3601,6 +3616,13 @@ def test_mean(self):
x = np.array([65504, 65504, 65504], dtype="float16")
self.assertAllClose(knp.mean(x), np.mean(x))

def test_array_split(self):
x = np.array([[1, 2, 3], [4, 5, 6]])
self.assertAllClose(knp.array_split(x, 2), np.array_split(x, 2))
self.assertAllClose(
knp.array_split(x, 3, axis=1), np.array_split(x, 3, axis=1)
)

def test_all(self):
x = np.array([[True, False, True], [True, True, True]])
self.assertAllClose(knp.all(x), np.all(x))
Expand Down Expand Up @@ -5920,6 +5942,19 @@ def test_add(self, dtypes):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_array_split(self, dtype):
import jax.numpy as jnp

x = knp.ones((1, 2), dtype=dtype)
x_jax = jnp.ones((1, 2), dtype=dtype)
expected_dtype = standardize_dtype(jnp.split(x_jax, 2, -1)[0].dtype)

self.assertEqual(
standardize_dtype(knp.split(x, 2, -1)[0].dtype),
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_add_python_types(self, dtype):
import jax.numpy as jnp
Expand Down