Skip to content

Commit 1519bcc

Browse files
Add keras.ops.array_split for Tensor Parallelism Support (#21697)
* Added tensor parallel for keras (Part 1/3) * Removed unnecessary lines * Fixes suggested by Gemini * Fixes suggested by Gemini * Fixes suggested by Gemini * Fixes suggested by Gemini * Fixes suggested by Gemini * Fixes suggested by Gemini * Fixing the failing test * Fixing the failing test * Fixing test * Adding tests for distributed_backends * Modifications for failing tests * Modified for failing test * Modified for failing test * Modified for failing test * added debuggers * removed debuggers * Removed the tensorflow, numpy and torch backends * Refactoring the code * Refactoring the code * refactoring * Adding necessary docstrings * Removing redundancies * Modifying tests * Reformatting * Reformatting the code * Fixing failing tests * fixes * Fixing tests * formatting * fixing test * fixing test * Removing redundant lines * Refactoring to remove communications.py and state_action_keras.py * formatting the files * fixing skip issues * fixing test * fixing test * refactoring to remove distributed backend wrapper * fixing test * making distrubed backend more jax friendly * Fixing comments * Fixing comments * Fixing comments * fixes * Refactor * refactoring to resolve comments * fixes * fixes * fix * fix * removing get_best_devices * fixing comments * fixing merge conflict * modifying variable name * Fixes * fix * fix * splitting into 3 PRs * Modified array_split implementation in openvino, tensorflow and torch * formatting the array split function * adding test for uneven array split * fixing test
1 parent cf94cdc commit 1519bcc

File tree

11 files changed

+215
-0
lines changed

11 files changed

+215
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@
140140
from keras.src.ops.numpy import argpartition as argpartition
141141
from keras.src.ops.numpy import argsort as argsort
142142
from keras.src.ops.numpy import array as array
143+
from keras.src.ops.numpy import array_split as array_split
143144
from keras.src.ops.numpy import average as average
144145
from keras.src.ops.numpy import bartlett as bartlett
145146
from keras.src.ops.numpy import bincount as bincount

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from keras.src.ops.numpy import argpartition as argpartition
2727
from keras.src.ops.numpy import argsort as argsort
2828
from keras.src.ops.numpy import array as array
29+
from keras.src.ops.numpy import array_split as array_split
2930
from keras.src.ops.numpy import average as average
3031
from keras.src.ops.numpy import bartlett as bartlett
3132
from keras.src.ops.numpy import bincount as bincount

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@
140140
from keras.src.ops.numpy import argpartition as argpartition
141141
from keras.src.ops.numpy import argsort as argsort
142142
from keras.src.ops.numpy import array as array
143+
from keras.src.ops.numpy import array_split as array_split
143144
from keras.src.ops.numpy import average as average
144145
from keras.src.ops.numpy import bartlett as bartlett
145146
from keras.src.ops.numpy import bincount as bincount

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from keras.src.ops.numpy import argpartition as argpartition
2727
from keras.src.ops.numpy import argsort as argsort
2828
from keras.src.ops.numpy import array as array
29+
from keras.src.ops.numpy import array_split as array_split
2930
from keras.src.ops.numpy import average as average
3031
from keras.src.ops.numpy import bartlett as bartlett
3132
from keras.src.ops.numpy import bincount as bincount

keras/src/backend/jax/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,11 @@ def split(x, indices_or_sections, axis=0):
11671167
return jnp.split(x, indices_or_sections, axis=axis)
11681168

11691169

1170+
def array_split(x, indices_or_sections, axis=0):
1171+
x = convert_to_tensor(x)
1172+
return jnp.array_split(x, indices_or_sections, axis=axis)
1173+
1174+
11701175
def stack(x, axis=0):
11711176
x = [convert_to_tensor(t) for t in x]
11721177
return jnp.stack(x, axis=axis)

keras/src/backend/numpy/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,6 +1107,11 @@ def split(x, indices_or_sections, axis=0):
11071107
return np.split(x, indices_or_sections, axis=axis)
11081108

11091109

1110+
def array_split(x, indices_or_sections, axis=0):
1111+
axis = standardize_axis_for_numpy(axis)
1112+
return np.array_split(x, indices_or_sections, axis=axis)
1113+
1114+
11101115
def stack(x, axis=0):
11111116
axis = standardize_axis_for_numpy(axis)
11121117
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])

keras/src/backend/openvino/numpy.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2014,6 +2014,37 @@ def split(x, indices_or_sections, axis=0):
20142014
)
20152015

20162016

2017+
def array_split(x, indices_or_sections, axis=0):
2018+
original_shape = x.shape
2019+
x = get_ov_output(x)
2020+
2021+
num_splits_val = indices_or_sections
2022+
total_size = original_shape[axis]
2023+
if total_size is None:
2024+
raise ValueError(
2025+
f"Cannot use array_split with static Python logic on dynamic axis. "
2026+
f"Axis {axis} has unknown dimension for shape {original_shape}."
2027+
)
2028+
2029+
base_size = total_size // num_splits_val
2030+
remainder = total_size % num_splits_val
2031+
2032+
split_lengths = [base_size + 1] * remainder + [base_size] * (
2033+
num_splits_val - remainder
2034+
)
2035+
split_lengths_tensor = ov_opset.constant(
2036+
split_lengths, dtype=Type.i64
2037+
).output(0)
2038+
2039+
axis_tensor = ov_opset.constant(axis, dtype=Type.i32).output(0)
2040+
splits = ov_opset.variadic_split(x, axis_tensor, split_lengths_tensor)
2041+
2042+
result = []
2043+
for i in range(num_splits_val):
2044+
result.append(OpenVINOKerasTensor(splits.output(i)))
2045+
return result
2046+
2047+
20172048
def stack(x, axis=0):
20182049
if isinstance(x, tuple):
20192050
x = list(x)

keras/src/backend/tensorflow/numpy.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2494,6 +2494,17 @@ def split(x, indices_or_sections, axis=0):
24942494
return tf.split(x, num_or_size_splits, axis=axis)
24952495

24962496

2497+
def array_split(x, indices_or_sections, axis=0):
2498+
x = tf.convert_to_tensor(x)
2499+
num_splits = indices_or_sections
2500+
total_size = shape_op(x)[axis]
2501+
avg_size = total_size // num_splits
2502+
remainder = total_size % num_splits
2503+
sizes = [avg_size + 1] * remainder + [avg_size] * (num_splits - remainder)
2504+
2505+
return tf.split(x, sizes, axis=axis)
2506+
2507+
24972508
def stack(x, axis=0):
24982509
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
24992510
if len(dtype_set) > 1:

keras/src/backend/torch/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1539,6 +1539,12 @@ def split(x, indices_or_sections, axis=0):
15391539
return list(out)
15401540

15411541

1542+
def array_split(x, indices_or_sections, axis=0):
1543+
x = convert_to_tensor(x)
1544+
out = torch.tensor_split(x, indices_or_sections, dim=axis)
1545+
return list(out)
1546+
1547+
15421548
def stack(x, axis=0):
15431549
x = [convert_to_tensor(elem) for elem in x]
15441550
return torch.stack(x, dim=axis)

keras/src/ops/numpy.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7762,3 +7762,104 @@ def histogram(x, bins=10, range=None):
77627762
f"Received: input.shape={x.shape}"
77637763
)
77647764
return backend.numpy.histogram(x, bins=bins, range=range)
7765+
7766+
7767+
class ArraySplit(Operation):
7768+
def __init__(self, indices_or_sections, axis=0, *, name=None):
7769+
super().__init__(name=name)
7770+
7771+
self.indices_or_sections = indices_or_sections
7772+
self.axis = axis
7773+
7774+
def call(self, x):
7775+
return backend.numpy.array_split(
7776+
x,
7777+
indices_or_sections=self.indices_or_sections,
7778+
axis=self.axis,
7779+
)
7780+
7781+
def compute_output_spec(self, x):
7782+
num_splits = self.indices_or_sections
7783+
7784+
axis = self.axis
7785+
if axis < 0:
7786+
axis += len(x.shape)
7787+
7788+
total_size = x.shape[axis]
7789+
7790+
if total_size is None:
7791+
output_specs = []
7792+
base_shape = list(x.shape)
7793+
base_shape[axis] = None
7794+
for _ in range(num_splits):
7795+
output_specs.append(
7796+
KerasTensor(shape=tuple(base_shape), dtype=x.dtype)
7797+
)
7798+
return tuple(output_specs)
7799+
7800+
split_size = total_size // num_splits
7801+
remainder = total_size % num_splits
7802+
7803+
output_specs = []
7804+
base_shape = list(x.shape)
7805+
for i in range(num_splits):
7806+
size = split_size + (1 if i < remainder else 0)
7807+
shape = base_shape.copy()
7808+
shape[axis] = size
7809+
output_specs.append(KerasTensor(shape=tuple(shape), dtype=x.dtype))
7810+
7811+
return list(output_specs)
7812+
7813+
7814+
@keras_export(["keras.ops.array_split", "keras.ops.numpy.array_split"])
7815+
def array_split(x, indices_or_sections, axis=0):
7816+
"""Splits an array into multiple sub-arrays (unevenly).
7817+
7818+
This is similar to `keras.ops.split`, but it allows for
7819+
unequal splits. `indices_or_sections` must be an integer
7820+
that indicates the total number of sub-arrays to create.
7821+
If the tensor cannot be divided evenly, the first `remainder`
7822+
splits will have size `quotient + 1`, and the rest will
7823+
have size `quotient`.
7824+
7825+
Args:
7826+
x: Input tensor.
7827+
indices_or_sections: An integer indicating the number of
7828+
sub-arrays to create.
7829+
axis: The axis along which to split. Defaults to 0.
7830+
7831+
Returns:
7832+
A list of sub-tensors.
7833+
7834+
Example:
7835+
>>> x = keras.ops.arange(10)
7836+
>>> keras.ops.array_split(x, 3)
7837+
(array([0, 1, 2, 3], dtype=int32),
7838+
array([4, 5, 6], dtype=int32),
7839+
array([7, 8, 9], dtype=int32))
7840+
"""
7841+
if not isinstance(indices_or_sections, int):
7842+
raise TypeError(
7843+
"Argument `indices_or_sections` must be of type `int`. "
7844+
f"Received: indices_or_sections={indices_or_sections}"
7845+
)
7846+
7847+
if indices_or_sections <= 0:
7848+
raise ValueError(
7849+
"Argument `indices_or_sections` must be a positive integer. "
7850+
f"Received: indices_or_sections={indices_or_sections}"
7851+
)
7852+
7853+
if not isinstance(axis, int):
7854+
raise TypeError(
7855+
f"Argument `axis` must be of type `int`. Received: {axis}"
7856+
)
7857+
7858+
if any_symbolic_tensors((x,)):
7859+
return ArraySplit(
7860+
indices_or_sections=indices_or_sections, axis=axis
7861+
).symbolic_call(x)
7862+
7863+
return backend.numpy.array_split(
7864+
x, indices_or_sections=indices_or_sections, axis=axis
7865+
)

0 commit comments

Comments
 (0)