@@ -7762,3 +7762,104 @@ def histogram(x, bins=10, range=None):
77627762 f"Received: input.shape={ x .shape } "
77637763 )
77647764 return backend .numpy .histogram (x , bins = bins , range = range )
7765+
7766+
7767+ class ArraySplit (Operation ):
7768+ def __init__ (self , indices_or_sections , axis = 0 , * , name = None ):
7769+ super ().__init__ (name = name )
7770+
7771+ self .indices_or_sections = indices_or_sections
7772+ self .axis = axis
7773+
7774+ def call (self , x ):
7775+ return backend .numpy .array_split (
7776+ x ,
7777+ indices_or_sections = self .indices_or_sections ,
7778+ axis = self .axis ,
7779+ )
7780+
7781+ def compute_output_spec (self , x ):
7782+ num_splits = self .indices_or_sections
7783+
7784+ axis = self .axis
7785+ if axis < 0 :
7786+ axis += len (x .shape )
7787+
7788+ total_size = x .shape [axis ]
7789+
7790+ if total_size is None :
7791+ output_specs = []
7792+ base_shape = list (x .shape )
7793+ base_shape [axis ] = None
7794+ for _ in range (num_splits ):
7795+ output_specs .append (
7796+ KerasTensor (shape = tuple (base_shape ), dtype = x .dtype )
7797+ )
7798+ return tuple (output_specs )
7799+
7800+ split_size = total_size // num_splits
7801+ remainder = total_size % num_splits
7802+
7803+ output_specs = []
7804+ base_shape = list (x .shape )
7805+ for i in range (num_splits ):
7806+ size = split_size + (1 if i < remainder else 0 )
7807+ shape = base_shape .copy ()
7808+ shape [axis ] = size
7809+ output_specs .append (KerasTensor (shape = tuple (shape ), dtype = x .dtype ))
7810+
7811+ return list (output_specs )
7812+
7813+
7814+ @keras_export (["keras.ops.array_split" , "keras.ops.numpy.array_split" ])
7815+ def array_split (x , indices_or_sections , axis = 0 ):
7816+ """Splits an array into multiple sub-arrays (unevenly).
7817+
7818+ This is similar to `keras.ops.split`, but it allows for
7819+ unequal splits. `indices_or_sections` must be an integer
7820+ that indicates the total number of sub-arrays to create.
7821+ If the tensor cannot be divided evenly, the first `remainder`
7822+ splits will have size `quotient + 1`, and the rest will
7823+ have size `quotient`.
7824+
7825+ Args:
7826+ x: Input tensor.
7827+ indices_or_sections: An integer indicating the number of
7828+ sub-arrays to create.
7829+ axis: The axis along which to split. Defaults to 0.
7830+
7831+ Returns:
7832+ A list of sub-tensors.
7833+
7834+ Example:
7835+ >>> x = keras.ops.arange(10)
7836+ >>> keras.ops.array_split(x, 3)
7837+ (array([0, 1, 2, 3], dtype=int32),
7838+ array([4, 5, 6], dtype=int32),
7839+ array([7, 8, 9], dtype=int32))
7840+ """
7841+ if not isinstance (indices_or_sections , int ):
7842+ raise TypeError (
7843+ "Argument `indices_or_sections` must be of type `int`. "
7844+ f"Received: indices_or_sections={ indices_or_sections } "
7845+ )
7846+
7847+ if indices_or_sections <= 0 :
7848+ raise ValueError (
7849+ "Argument `indices_or_sections` must be a positive integer. "
7850+ f"Received: indices_or_sections={ indices_or_sections } "
7851+ )
7852+
7853+ if not isinstance (axis , int ):
7854+ raise TypeError (
7855+ f"Argument `axis` must be of type `int`. Received: { axis } "
7856+ )
7857+
7858+ if any_symbolic_tensors ((x ,)):
7859+ return ArraySplit (
7860+ indices_or_sections = indices_or_sections , axis = axis
7861+ ).symbolic_call (x )
7862+
7863+ return backend .numpy .array_split (
7864+ x , indices_or_sections = indices_or_sections , axis = axis
7865+ )
0 commit comments