diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py
index 9578ed614a9..15a0d67a422 100644
--- a/keras/api/_tf_keras/keras/ops/__init__.py
+++ b/keras/api/_tf_keras/keras/ops/__init__.py
@@ -140,6 +140,7 @@
 from keras.src.ops.numpy import argpartition as argpartition
 from keras.src.ops.numpy import argsort as argsort
 from keras.src.ops.numpy import array as array
+from keras.src.ops.numpy import array_split as array_split
 from keras.src.ops.numpy import average as average
 from keras.src.ops.numpy import bartlett as bartlett
 from keras.src.ops.numpy import bincount as bincount
diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py
index f4e450aef7d..9a1d473cac0 100644
--- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py
+++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py
@@ -26,6 +26,7 @@
 from keras.src.ops.numpy import argpartition as argpartition
 from keras.src.ops.numpy import argsort as argsort
 from keras.src.ops.numpy import array as array
+from keras.src.ops.numpy import array_split as array_split
 from keras.src.ops.numpy import average as average
 from keras.src.ops.numpy import bartlett as bartlett
 from keras.src.ops.numpy import bincount as bincount
diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py
index 9578ed614a9..15a0d67a422 100644
--- a/keras/api/ops/__init__.py
+++ b/keras/api/ops/__init__.py
@@ -140,6 +140,7 @@
 from keras.src.ops.numpy import argpartition as argpartition
 from keras.src.ops.numpy import argsort as argsort
 from keras.src.ops.numpy import array as array
+from keras.src.ops.numpy import array_split as array_split
 from keras.src.ops.numpy import average as average
 from keras.src.ops.numpy import bartlett as bartlett
 from keras.src.ops.numpy import bincount as bincount
diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py
index f4e450aef7d..9a1d473cac0 100644
--- a/keras/api/ops/numpy/__init__.py
+++ b/keras/api/ops/numpy/__init__.py
@@ -26,6 +26,7 @@
 from keras.src.ops.numpy import argpartition as argpartition
 from keras.src.ops.numpy import argsort as argsort
 from keras.src.ops.numpy import array as array
+from keras.src.ops.numpy import array_split as array_split
 from keras.src.ops.numpy import average as average
 from keras.src.ops.numpy import bartlett as bartlett
 from keras.src.ops.numpy import bincount as bincount
diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py
index 9b04a317ac4..c284b21cdae 100644
--- a/keras/src/backend/jax/numpy.py
+++ b/keras/src/backend/jax/numpy.py
@@ -1167,6 +1167,11 @@ def split(x, indices_or_sections, axis=0):
     return jnp.split(x, indices_or_sections, axis=axis)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    return jnp.array_split(x, indices_or_sections, axis=axis)
+
+
 def stack(x, axis=0):
     x = [convert_to_tensor(t) for t in x]
     return jnp.stack(x, axis=axis)
diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py
index fa44a5537ac..e5f4284b3db 100644
--- a/keras/src/backend/numpy/numpy.py
+++ b/keras/src/backend/numpy/numpy.py
@@ -1107,6 +1107,11 @@ def split(x, indices_or_sections, axis=0):
     return np.split(x, indices_or_sections, axis=axis)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    axis = standardize_axis_for_numpy(axis)
+    return np.array_split(x, indices_or_sections, axis=axis)
+
+
 def stack(x, axis=0):
     axis = standardize_axis_for_numpy(axis)
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py
index ae452910db7..a042471aa87 100644
--- a/keras/src/backend/openvino/numpy.py
+++ b/keras/src/backend/openvino/numpy.py
@@ -2014,6 +2014,38 @@ def split(x, indices_or_sections, axis=0):
     )
 
 
+def array_split(x, indices_or_sections, axis=0):
+    original_shape = x.shape
+    x = get_ov_output(x)
+
+    num_splits_val = indices_or_sections
+    total_size = original_shape[axis]
+    if total_size is None:
+        raise ValueError(
+            "Cannot use array_split with static Python logic on dynamic axis. "
+            f"Axis {axis} has unknown dimension for shape {original_shape}."
+        )
+
+    base_size = total_size // num_splits_val
+    remainder = total_size % num_splits_val
+
+    # Mirror `np.array_split`: the first `remainder` parts get one extra item.
+    split_lengths = [base_size + 1] * remainder + [base_size] * (
+        num_splits_val - remainder
+    )
+    split_lengths_tensor = ov_opset.constant(
+        split_lengths, dtype=Type.i64
+    ).output(0)
+
+    axis_tensor = ov_opset.constant(axis, dtype=Type.i32).output(0)
+    splits = ov_opset.variadic_split(x, axis_tensor, split_lengths_tensor)
+
+    result = []
+    for i in range(num_splits_val):
+        result.append(OpenVINOKerasTensor(splits.output(i)))
+    return result
+
+
 def stack(x, axis=0):
     if isinstance(x, tuple):
         x = list(x)
diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py
index d6f54719bfc..f127b22717b 100644
--- a/keras/src/backend/tensorflow/numpy.py
+++ b/keras/src/backend/tensorflow/numpy.py
@@ -2494,6 +2494,18 @@ def split(x, indices_or_sections, axis=0):
     return tf.split(x, num_or_size_splits, axis=axis)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = tf.convert_to_tensor(x)
+    num_splits = indices_or_sections
+    total_size = shape_op(x)[axis]
+    avg_size = total_size // num_splits
+    remainder = total_size % num_splits
+    # The first `remainder` chunks get one extra element, as in NumPy.
+    sizes = [avg_size + 1] * remainder + [avg_size] * (num_splits - remainder)
+
+    return tf.split(x, sizes, axis=axis)
+
+
 def stack(x, axis=0):
     dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
     if len(dtype_set) > 1:
diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py
index d3dd3d09f80..cfd844f24b6 100644
--- a/keras/src/backend/torch/numpy.py
+++ b/keras/src/backend/torch/numpy.py
@@ -1539,6 +1539,12 @@ def split(x, indices_or_sections, axis=0):
     return list(out)
 
 
+def array_split(x, indices_or_sections, axis=0):
+    x = convert_to_tensor(x)
+    out = torch.tensor_split(x, indices_or_sections, dim=axis)
+    return list(out)
+
+
 def stack(x, axis=0):
     x = [convert_to_tensor(elem) for elem in x]
     return torch.stack(x, dim=axis)
diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py
index 63e682b3332..3e128912c94 100644
--- a/keras/src/ops/numpy.py
+++ b/keras/src/ops/numpy.py
@@ -7762,3 +7762,101 @@ def histogram(x, bins=10, range=None):
             f"Received: input.shape={x.shape}"
         )
     return backend.numpy.histogram(x, bins=bins, range=range)
+
+
+class ArraySplit(Operation):
+    """Operation that splits a tensor into `indices_or_sections` parts."""
+
+    def __init__(self, indices_or_sections, axis=0, *, name=None):
+        super().__init__(name=name)
+
+        self.indices_or_sections = indices_or_sections
+        self.axis = axis
+
+    def call(self, x):
+        return backend.numpy.array_split(
+            x,
+            indices_or_sections=self.indices_or_sections,
+            axis=self.axis,
+        )
+
+    def compute_output_spec(self, x):
+        num_splits = self.indices_or_sections
+
+        axis = self.axis
+        if axis < 0:
+            axis += len(x.shape)
+
+        total_size = x.shape[axis]
+        if total_size is None:
+            # Unknown length along `axis`: every section keeps it unknown.
+            sizes = [None] * num_splits
+        else:
+            # The first `remainder` sections get one extra element.
+            split_size, remainder = divmod(total_size, num_splits)
+            sizes = [
+                split_size + (1 if i < remainder else 0)
+                for i in range(num_splits)
+            ]
+
+        output_specs = []
+        for size in sizes:
+            shape = list(x.shape)
+            shape[axis] = size
+            output_specs.append(KerasTensor(shape=tuple(shape), dtype=x.dtype))
+        # Always a list, whether or not the split dimension is known.
+        return output_specs
+
+
+@keras_export(["keras.ops.array_split", "keras.ops.numpy.array_split"])
+def array_split(x, indices_or_sections, axis=0):
+    """Splits an array into multiple sub-arrays (unevenly).
+
+    This is similar to `keras.ops.split`, but it allows for
+    unequal splits. `indices_or_sections` must be an integer
+    that indicates the total number of sub-arrays to create.
+    If the tensor cannot be divided evenly, the first `remainder`
+    splits will have size `quotient + 1`, and the rest will
+    have size `quotient`.
+
+    Args:
+        x: Input tensor.
+        indices_or_sections: An integer indicating the number of
+            sub-arrays to create.
+        axis: The axis along which to split. Defaults to 0.
+
+    Returns:
+        A list of sub-tensors.
+
+    Example:
+    >>> x = keras.ops.arange(10)
+    >>> keras.ops.array_split(x, 3)
+    [array([0, 1, 2, 3], dtype=int32),
+     array([4, 5, 6], dtype=int32),
+     array([7, 8, 9], dtype=int32)]
+    """
+    if not isinstance(indices_or_sections, int):
+        raise TypeError(
+            "Argument `indices_or_sections` must be of type `int`. "
+            f"Received: indices_or_sections={indices_or_sections}"
+        )
+
+    if indices_or_sections <= 0:
+        raise ValueError(
+            "Argument `indices_or_sections` must be a positive integer. "
+            f"Received: indices_or_sections={indices_or_sections}"
+        )
+
+    if not isinstance(axis, int):
+        raise TypeError(
+            f"Argument `axis` must be of type `int`. Received: {axis}"
+        )
+
+    if any_symbolic_tensors((x,)):
+        return ArraySplit(
+            indices_or_sections=indices_or_sections, axis=axis
+        ).symbolic_call(x)
+
+    return backend.numpy.array_split(
+        x, indices_or_sections=indices_or_sections, axis=axis
+    )
diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py
index 42976bdcbcc..708bb74f562 100644
--- a/keras/src/ops/numpy_test.py
+++ b/keras/src/ops/numpy_test.py
@@ -1869,6 +1869,13 @@ def test_view(self):
         self.assertEqual(knp.view(x, dtype="int32").shape, (None, 2))
         self.assertEqual(knp.view(x, dtype="int32").dtype, "int32")
 
+    def test_array_split(self):
+        x = KerasTensor((None, 4))
+        splits = knp.array_split(x, 2, axis=0)
+        self.assertEqual(len(splits), 2)
+        self.assertEqual(splits[0].shape, (None, 4))
+        self.assertEqual(splits[1].shape, (None, 4))
+
 
 class NumpyOneInputOpsStaticShapeTest(testing.TestCase):
     def test_mean(self):
@@ -2481,6 +2488,14 @@ def test_view(self):
         self.assertEqual(knp.view(x, dtype="int32").shape, (2, 2))
         self.assertEqual(knp.view(x, dtype="int32").dtype, "int32")
 
+    def test_array_split(self):
+        x = KerasTensor((8, 4))
+        splits = knp.array_split(x, 3, axis=0)
+        self.assertEqual(len(splits), 3)
+        self.assertEqual(splits[0].shape, (3, 4))
+        self.assertEqual(splits[1].shape, (3, 4))
+        self.assertEqual(splits[2].shape, (2, 4))
+
 
 class NumpyTwoInputOpsCorrectnessTest(testing.TestCase):
     def test_add(self):
@@ -3601,6 +3616,30 @@ def test_mean(self):
         x = np.array([65504, 65504, 65504], dtype="float16")
         self.assertAllClose(knp.mean(x), np.mean(x))
 
+    def test_array_split(self):
+        x = np.array([[1, 2, 3], [4, 5, 6]])
+
+        # Even split (axis=0)
+        knp_res1 = knp.array_split(x, 2)
+        np_res1 = np.array_split(x, 2)
+        self.assertEqual(len(knp_res1), len(np_res1))
+        for k_arr, n_arr in zip(knp_res1, np_res1):
+            self.assertAllClose(k_arr, n_arr)
+
+        # Even split (axis=1)
+        knp_res2 = knp.array_split(x, 3, axis=1)
+        np_res2 = np.array_split(x, 3, axis=1)
+        self.assertEqual(len(knp_res2), len(np_res2))
+        for k_arr, n_arr in zip(knp_res2, np_res2):
+            self.assertAllClose(k_arr, n_arr)
+
+        # Uneven split (axis=1) - 3 columns into 2 sections
+        knp_res3 = knp.array_split(x, 2, axis=1)
+        np_res3 = np.array_split(x, 2, axis=1)
+        self.assertEqual(len(knp_res3), len(np_res3))
+        for k_arr, n_arr in zip(knp_res3, np_res3):
+            self.assertAllClose(k_arr, n_arr)
+
     def test_all(self):
         x = np.array([[True, False, True], [True, True, True]])
         self.assertAllClose(knp.all(x), np.all(x))
@@ -5920,6 +5959,21 @@ def test_add(self, dtypes):
             expected_dtype,
         )
 
+    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
+    def test_array_split(self, dtype):
+        import jax.numpy as jnp
+
+        x = knp.ones((1, 2), dtype=dtype)
+        x_jax = jnp.ones((1, 2), dtype=dtype)
+        expected_dtype = standardize_dtype(
+            jnp.array_split(x_jax, 2, -1)[0].dtype
+        )
+
+        self.assertEqual(
+            standardize_dtype(knp.array_split(x, 2, -1)[0].dtype),
+            expected_dtype,
+        )
+
     @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
     def test_add_python_types(self, dtype):
         import jax.numpy as jnp