diff --git a/Source/MLX/ArrayAt.swift b/Source/MLX/ArrayAt.swift index 8fdb4a6c..d5774788 100644 --- a/Source/MLX/ArrayAt.swift +++ b/Source/MLX/ArrayAt.swift @@ -34,7 +34,7 @@ public struct ArrayAt { /// ### See Also /// - ``MLXArray/at`` /// - ``ArrayAtIndices`` - public subscript(indices: MLXArrayIndex..., stream stream: StreamOrDevice = .default) + public subscript(indices: any MLXArrayIndex..., stream stream: StreamOrDevice = .default) -> ArrayAtIndices { get { @@ -58,7 +58,9 @@ public struct ArrayAt { /// ### See Also /// - ``MLXArray/at`` /// - ``ArrayAtIndices`` - public subscript(indices: [MLXArrayIndex], stream stream: StreamOrDevice = .default) + public subscript(indices: some Sequence, + stream stream: StreamOrDevice = .default + ) -> ArrayAtIndices { get { @@ -99,7 +101,7 @@ public struct ArrayAtIndices { /// /// ### See Also /// - ``MLXArray/at`` - public func add(_ values: ScalarOrArray) -> MLXArray { + public func add(_ values: some ScalarOrArray) -> MLXArray { let values = values.asMLXArray(dtype: array.dtype) let (indices, update, axes) = scatterArguments( src: array, operations: indexOperations, update: values, stream: stream) @@ -128,7 +130,7 @@ public struct ArrayAtIndices { /// /// ### See Also /// - ``MLXArray/at`` - public func subtract(_ values: ScalarOrArray) -> MLXArray { + public func subtract(_ values: some ScalarOrArray) -> MLXArray { add(-values.asMLXArray(dtype: array.dtype)) } @@ -142,7 +144,7 @@ public struct ArrayAtIndices { /// /// ### See Also /// - ``MLXArray/at`` - public func multiply(_ values: ScalarOrArray) -> MLXArray { + public func multiply(_ values: some ScalarOrArray) -> MLXArray { let values = values.asMLXArray(dtype: array.dtype) let (indices, update, axes) = scatterArguments( src: array, operations: indexOperations, update: values, stream: stream) @@ -171,7 +173,7 @@ public struct ArrayAtIndices { /// /// ### See Also /// - ``MLXArray/at`` - public func divide(_ values: ScalarOrArray) -> MLXArray { + public func divide(_ values: some ScalarOrArray) -> MLXArray { multiply(values.asMLXArray(dtype: array.dtype).reciprocal()) } @@ -185,7 +187,7 @@ public struct ArrayAtIndices { /// /// ### See Also /// - ``MLXArray/at`` - public func minimum(_ values: ScalarOrArray) -> MLXArray { + public func minimum(_ values: some ScalarOrArray) -> MLXArray { let values = values.asMLXArray(dtype: array.dtype) let (indices, update, axes) = scatterArguments( src: array, operations: indexOperations, update: values, stream: stream) @@ -214,7 +216,7 @@ public struct ArrayAtIndices { /// /// ### See Also /// - ``MLXArray/at`` - public func maximum(_ values: ScalarOrArray) -> MLXArray { + public func maximum(_ values: some ScalarOrArray) -> MLXArray { let values = values.asMLXArray(dtype: array.dtype) let (indices, update, axes) = scatterArguments( src: array, operations: indexOperations, update: values, stream: stream) diff --git a/Source/MLX/Cmlx+Util.swift b/Source/MLX/Cmlx+Util.swift index 46e17d28..61854afa 100644 --- a/Source/MLX/Cmlx+Util.swift +++ b/Source/MLX/Cmlx+Util.swift @@ -4,7 +4,7 @@ import Cmlx import Foundation // return a +1 mlx_vector_array containing the given arrays -func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array { +func new_mlx_vector_array(_ arrays: some Collection) -> mlx_vector_array { withExtendedLifetime(arrays) { mlx_vector_array_new_data(arrays.map { $0.ctx }, arrays.count) } diff --git a/Source/MLX/FFT.swift b/Source/MLX/FFT.swift index 45ef3f6c..04e4a425 100644 --- a/Source/MLX/FFT.swift +++ b/Source/MLX/FFT.swift @@ -62,7 +62,8 @@ public enum MLXFFT { /// ### See Also /// - public static func fft2( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [-2, -1], stream: StreamOrDevice = .default ) -> MLXArray { fftn(array, s: s, axes: axes, stream: stream) @@ -81,7 +82,8 @@ public enum MLXFFT { /// ### See Also /// - public static func ifft2( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [-2, -1], stream: StreamOrDevice = .default ) -> MLXArray { ifftn(array, s: s, axes: axes, stream: stream) @@ -100,7 +102,8 @@ public enum MLXFFT { /// ### See Also /// - public static func fftn( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [Int]?.none, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() if let s, let axes { @@ -146,7 +149,8 @@ public enum MLXFFT { /// ### See Also /// - public static func ifftn( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [Int]?.none, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() if let s, let axes { @@ -244,7 +248,8 @@ public enum MLXFFT { /// ### See Also /// - public static func rfft2( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [-2, -1], stream: StreamOrDevice = .default ) -> MLXArray { rfftn(array, s: s, axes: axes, stream: stream) @@ -268,7 +273,8 @@ public enum MLXFFT { /// ### See Also /// - public static func irfft2( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [-2, -1], stream: StreamOrDevice = .default ) -> MLXArray { irfftn(array, s: s, axes: axes, stream: stream) @@ -291,7 +297,8 @@ public enum MLXFFT { /// ### See Also /// - public static func rfftn( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [Int]?.none, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() if let s, let axes { @@ -342,7 +349,8 @@ public enum MLXFFT { /// ### See Also /// - public static func irfftn( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [Int]?.none, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() if let s, let axes { @@ -428,7 +436,9 @@ public func ifft( /// ### See Also /// - public func fft2( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [-2, -1], + stream: StreamOrDevice = .default ) -> MLXArray { MLXFFT.fft2(array, s: s, axes: axes, stream: stream) } @@ -446,7 +456,9 @@ public func fft2( /// ### See Also /// - public func ifft2( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [-2, -1], + stream: StreamOrDevice = .default ) -> MLXArray { MLXFFT.ifft2(array, s: s, axes: axes, stream: stream) } @@ -464,7 +476,8 @@ public func ifft2( /// ### See Also /// - public func fftn( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [Int]?.none, stream: StreamOrDevice = .default ) -> MLXArray { MLXFFT.fftn(array, s: s, axes: axes, stream: stream) } @@ -482,7 +495,8 @@ public func fftn( /// ### See Also /// - public func ifftn( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [Int]?.none, stream: StreamOrDevice = .default ) -> MLXArray { MLXFFT.ifftn(array, s: s, axes: axes, stream: stream) } @@ -546,7 +560,9 @@ public func irfft( /// ### See Also /// - public func rfft2( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [-2, -1], + stream: StreamOrDevice = .default ) -> MLXArray { MLXFFT.rfft2(array, s: s, axes: axes, stream: stream) } @@ -569,7 +585,9 @@ public func rfft2( /// ### See Also /// - public func irfft2( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [-2, -1], + stream: StreamOrDevice = .default ) -> MLXArray { MLXFFT.irfft2(array, s: s, axes: axes, stream: stream) } @@ -591,7 +609,8 @@ public func irfft2( /// ### See Also /// - public func rfftn( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [Int]?.none, stream: StreamOrDevice = .default ) -> MLXArray { MLXFFT.rfftn(array, s: s, axes: axes, stream: stream) } @@ -614,7 +633,8 @@ public func rfftn( /// ### See Also /// - public func irfftn( - _ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default + _ array: MLXArray, s: (some Collection)? = [Int]?.none, + axes: (some Collection)? = [Int]?.none, stream: StreamOrDevice = .default ) -> MLXArray { MLXFFT.irfftn(array, s: s, axes: axes, stream: stream) } diff --git a/Source/MLX/Factory.swift b/Source/MLX/Factory.swift index c6e63a46..f46c60a6 100644 --- a/Source/MLX/Factory.swift +++ b/Source/MLX/Factory.swift @@ -23,7 +23,7 @@ extension MLXArray { /// - ``zeros(like:stream:)`` /// - ``ones(_:type:stream:)`` static public func zeros( - _ shape: [Int], type: T.Type = Float.self, stream: StreamOrDevice = .default + _ shape: some Collection, type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { MLX.zeros(shape, type: type, stream: stream) } @@ -46,7 +46,7 @@ extension MLXArray { /// - ``zeros(like:stream:)`` /// - ``ones(_:type:stream:)`` static public func zeros( - _ shape: [Int], dtype: DType = .float32, stream: StreamOrDevice = .default + _ shape: some Collection, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { MLX.zeros(shape, dtype: dtype, stream: stream) } @@ -90,7 +90,7 @@ extension MLXArray { /// - ``ones(like:stream:)`` /// - ``zeros(_:type:stream:)`` static public func ones( - _ shape: [Int], type: T.Type = Float.self, stream: StreamOrDevice = .default + _ shape: some Collection, type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { MLX.ones(shape, type: type, stream: stream) } @@ -113,7 +113,7 @@ extension MLXArray { /// - ``ones(like:stream:)`` /// - ``zeros(_:type:stream:)`` static public func ones( - _ shape: [Int], dtype: DType = .float32, stream: StreamOrDevice = .default + _ shape: some Collection, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { MLX.ones(shape, dtype: dtype, stream: stream) } @@ -185,7 +185,7 @@ extension MLXArray { /// - /// - ``identity(_:type:stream:)`` static public func eye( - _ n: Int, m: Int? = nil, k: Int = 0, dtype: DType = .float32, + _ n: Int, m: Int? = nil, k: Int = 0, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { MLX.eye(n, m: m, k: k, dtype: dtype, stream: stream) @@ -214,7 +214,8 @@ extension MLXArray { /// - ``full(_:values:stream:)`` /// - ``repeated(_:count:axis:stream:)`` static public func full( - _ shape: [Int], values: MLXArray, type: T.Type, stream: StreamOrDevice = .default + _ shape: some Collection, values: MLXArray, type: T.Type, + stream: StreamOrDevice = .default ) -> MLXArray { MLX.full(shape, values: values, type: type, stream: stream) } @@ -242,7 +243,8 @@ extension MLXArray { /// - ``full(_:values:stream:)`` /// - ``repeated(_:count:axis:stream:)`` static public func full( - _ shape: [Int], values: MLXArray, dtype: DType = .float32, stream: StreamOrDevice = .default + _ shape: some Collection, values: MLXArray, dtype: DType, + stream: StreamOrDevice = .default ) -> MLXArray { MLX.full(shape, values: values, dtype: dtype, stream: stream) } @@ -268,7 +270,9 @@ extension MLXArray { /// - /// - ``full(_:values:type:stream:)`` /// - ``repeated(_:count:axis:stream:)`` - static public func full(_ shape: [Int], values: MLXArray, stream: StreamOrDevice = .default) + static public func full( + _ shape: some Collection, values: MLXArray, stream: StreamOrDevice = .default + ) -> MLXArray { MLX.full(shape, values: values, stream: stream) @@ -315,7 +319,7 @@ extension MLXArray { /// - /// - ``eye(_:m:k:type:stream:)`` static public func identity( - _ n: Int, dtype: DType = .float32, stream: StreamOrDevice = .default + _ n: Int, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { MLX.identity(n, dtype: dtype, stream: stream) } @@ -491,7 +495,7 @@ extension MLXArray { /// ### See Also /// - static public func tri( - _ n: Int, m: Int? = nil, k: Int = 0, dtype: DType = .float32, + _ n: Int, m: Int? = nil, k: Int = 0, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { MLX.tri(n, m: m, k: k, dtype: dtype, stream: stream) @@ -517,7 +521,7 @@ extension MLXArray { /// - ``zeros(like:stream:)`` /// - ``ones(_:type:stream:)`` public func zeros( - _ shape: [Int], type: T.Type = Float.self, stream: StreamOrDevice = .default + _ shape: some Collection, type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_zeros(&result, shape.map { Int32($0) }, shape.count, T.dtype.cmlxDtype, stream.ctx) @@ -542,7 +546,7 @@ public func zeros( /// - ``zeros(like:stream:)`` /// - ``ones(_:type:stream:)`` public func zeros( - _ shape: [Int], dtype: DType = .float32, stream: StreamOrDevice = .default + _ shape: some Collection, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_zeros(&result, shape.map { Int32($0) }, shape.count, dtype.cmlxDtype, stream.ctx) @@ -590,7 +594,7 @@ public func zeros(like array: MLXArray, stream: StreamOrDevice = .default) -> ML /// - ``ones(like:stream:)`` /// - ``zeros(_:type:stream:)`` public func ones( - _ shape: [Int], type: T.Type = Float.self, stream: StreamOrDevice = .default + _ shape: some Collection, type: T.Type = Float.self, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_ones(&result, shape.map { Int32($0) }, shape.count, T.dtype.cmlxDtype, stream.ctx) @@ -615,7 +619,7 @@ public func ones( /// - ``zeros(like:stream:)`` /// - ``ones(_:type:stream:)`` public func ones( - _ shape: [Int], dtype: DType = .float32, stream: StreamOrDevice = .default + _ shape: some Collection, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_ones(&result, shape.map { Int32($0) }, shape.count, dtype.cmlxDtype, stream.ctx) @@ -693,7 +697,7 @@ public func eye( /// - /// - ``identity(_:type:stream:)`` public func eye( - _ n: Int, m: Int? = nil, k: Int = 0, dtype: DType = .float32, + _ n: Int, m: Int? = nil, k: Int = 0, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -724,7 +728,8 @@ public func eye( /// - ``full(_:values:stream:)`` /// - ``repeated(_:count:axis:stream:)`` public func full( - _ shape: [Int], values: ScalarOrArray, type: T.Type, stream: StreamOrDevice = .default + _ shape: some Collection, values: some ScalarOrArray, type: T.Type, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() let values = values.asMLXArray(dtype: nil) @@ -755,7 +760,7 @@ public func full( /// - ``full(_:values:stream:)`` /// - ``repeated(_:count:axis:stream:)`` public func full( - _ shape: [Int], values: MLXArray, dtype: DType = .float32, + _ shape: some Collection, values: MLXArray, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -784,7 +789,9 @@ public func full( /// - /// - ``full(_:values:type:stream:)`` /// - ``repeated(_:count:axis:stream:)`` -public func full(_ shape: [Int], values: ScalarOrArray, stream: StreamOrDevice = .default) +public func full( + _ shape: some Collection, values: some ScalarOrArray, stream: StreamOrDevice = .default +) -> MLXArray { var result = mlx_array_new() @@ -836,7 +843,7 @@ public func identity( /// - /// - ``eye(_:m:k:type:stream:)`` public func identity( - _ n: Int, dtype: DType = .float32, stream: StreamOrDevice = .default + _ n: Int, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_identity(&result, n.int32, dtype.cmlxDtype, stream.ctx) @@ -1024,7 +1031,7 @@ public func tri( /// ### See Also /// - public func tri( - _ n: Int, m: Int? = nil, k: Int = 0, dtype: DType = .float32, + _ n: Int, m: Int? = nil, k: Int = 0, dtype: DType, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() diff --git a/Source/MLX/Linalg.swift b/Source/MLX/Linalg.swift index 4e30cc81..852bfb65 100644 --- a/Source/MLX/Linalg.swift +++ b/Source/MLX/Linalg.swift @@ -58,7 +58,7 @@ public enum MLXLinalg { /// ### See Also /// - ``norm(_:ord:axes:keepDims:stream:)`` public static func norm( - _ array: MLXArray, ord: NormKind? = nil, axes: [Int], keepDims: Bool = false, + _ array: MLXArray, ord: NormKind? = nil, axes: some Collection, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -113,7 +113,7 @@ public enum MLXLinalg { /// ### See Also /// - ``norm(_:ord:axes:keepDims:stream:)-8zljj`` public static func norm( - _ array: MLXArray, ord: Double, axes: [Int], keepDims: Bool = false, + _ array: MLXArray, ord: Double, axes: some Collection, keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -460,7 +460,8 @@ public enum MLXLinalg { /// ### See Also /// - ``norm(_:ord:axes:keepDims:stream:)`` public func norm( - _ array: MLXArray, ord: MLXLinalg.NormKind? = nil, axes: [Int], keepDims: Bool = false, + _ array: MLXArray, ord: MLXLinalg.NormKind? = nil, axes: some Collection, + keepDims: Bool = false, stream: StreamOrDevice = .default ) -> MLXArray { return MLXLinalg.norm(array, ord: ord, axes: axes, keepDims: keepDims, stream: stream) diff --git a/Source/MLX/MLXArray+Bytes.swift b/Source/MLX/MLXArray+Bytes.swift index 00163b87..5e0b0718 100644 --- a/Source/MLX/MLXArray+Bytes.swift +++ b/Source/MLX/MLXArray+Bytes.swift @@ -285,7 +285,7 @@ extension MLXArray { /// - /// - ``asArray(_:)`` /// - ``asData(access:)`` - public func asMTLBuffer(device: any MTLDevice, noCopy: Bool = false) -> (any MTLBuffer)? { + public func asMTLBuffer(device: some MTLDevice, noCopy: Bool = false) -> (any MTLBuffer)? { self.eval() if noCopy && self.contiguousToDimension() == 0 { @@ -304,7 +304,7 @@ extension MLXArray { } /// Return the strides for contiguous memory -func contiguousStrides(shape: [Int]) -> [Int] { +func contiguousStrides(shape: some Collection) -> [Int] { var result = [Int]() var current = 1 for d in shape.reversed() { diff --git a/Source/MLX/MLXArray+Indexing.swift b/Source/MLX/MLXArray+Indexing.swift index a56f3b29..bc86cc4e 100644 --- a/Source/MLX/MLXArray+Indexing.swift +++ b/Source/MLX/MLXArray+Indexing.swift @@ -80,8 +80,9 @@ extension MLXArray { } } - private func resolve(_ rangeExpression: any RangeExpression, _ axis: Int) -> (Int32, Int32) - { + private func resolve(_ rangeExpression: some RangeExpression, _ axis: Int) -> ( + Int32, Int32 + ) { func resolve(_ index: Int, _ axis: Int) -> Int32 { if index < 0 { return Int32(index + dim(axis)) @@ -184,7 +185,7 @@ extension MLXArray { /// ### See Also /// - @available(*, deprecated, message: "please use subscript(.ellipsis, 0 ..< 3) or equivalent") - public subscript(range: any RangeExpression, axis axis: Int, + public subscript(range: some RangeExpression, axis axis: Int, stream stream: StreamOrDevice = .default ) -> MLXArray { get { @@ -447,7 +448,7 @@ extension MLXArray { /// - ``MLXArrayIndex/newAxis`` /// - ``MLXArrayIndex/stride(from:to:by:)`` /// - ``MLXArray/at`` - public subscript(indices: MLXArrayIndex..., stream stream: StreamOrDevice = .default) + public subscript(indices: any MLXArrayIndex..., stream stream: StreamOrDevice = .default) -> MLXArray { get { @@ -462,7 +463,9 @@ extension MLXArray { /// General array indexing. /// /// See ``MLXArray/subscript(_:stream:)-375a0`` - public subscript(indices: [MLXArrayIndex], stream stream: StreamOrDevice = .default) + public subscript(indices: some Sequence, + stream stream: StreamOrDevice = .default + ) -> MLXArray { get { @@ -478,7 +481,7 @@ extension MLXArray { // MARK: - Support -func countNonNewAxisOperations(_ operations: any Sequence) -> Int { +func countNonNewAxisOperations(_ operations: some Sequence) -> Int { operations .filter { !$0.isNewAxis } .count diff --git a/Source/MLX/MLXArray+Init.swift b/Source/MLX/MLXArray+Init.swift index cade9f21..d5705783 100644 --- a/Source/MLX/MLXArray+Init.swift +++ b/Source/MLX/MLXArray+Init.swift @@ -4,14 +4,14 @@ import Cmlx import Foundation import Numerics -private func shapePrecondition(shape: [Int]?, count: Int) { +private func shapePrecondition(shape: (some Collection)?, count: Int) { if let shape { let total = shape.reduce(1, *) precondition(total == count, "shape \(shape) total \(total) != \(count) (actual)") } } -private func shapePrecondition(shape: [Int]?, byteCount: Int, type: DType) { +private func shapePrecondition(shape: (some Collection)?, byteCount: Int, type: DType) { if let shape { let total = shape.reduce(1, *) * type.size precondition(total == byteCount, "shape \(shape) total \(total)B != \(byteCount)B (actual)") @@ -272,13 +272,15 @@ extension MLXArray { /// /// ### See Also /// - - public convenience init(_ value: [T], _ shape: [Int]? = nil) { + public convenience init( + _ value: [T], _ shape: (some Collection)? = [Int]?.none + ) { shapePrecondition(shape: shape, count: value.count) self.init( value.withUnsafeBufferPointer { ptr in - let shape = shape ?? [value.count] + let shape = shape?.asInt32 ?? [value.count.int32] return mlx_array_new_data( - ptr.baseAddress!, shape.asInt32, shape.count.int32, T.dtype.cmlxDtype) + ptr.baseAddress!, shape, shape.count.int32, T.dtype.cmlxDtype) }) } @@ -295,7 +297,7 @@ extension MLXArray { /// ### See Also /// - /// - ``init(int64:_:)-7bgj2`` - public convenience init(_ value: [Int], _ shape: [Int]? = nil) { + public convenience init(_ value: [Int], _ shape: (some Collection)? = [Int]?.none) { shapePrecondition(shape: shape, count: value.count) precondition( value.allSatisfy { (Int(Int32.min) ... Int(Int32.max)).contains($0) }, @@ -306,9 +308,9 @@ extension MLXArray { value .map { Int32($0) } .withUnsafeBufferPointer { ptr in - let shape = shape ?? [value.count] + let shape = shape?.asInt32 ?? [value.count.int32] return mlx_array_new_data( - ptr.baseAddress!, shape.asInt32, shape.count.int32, Int32.dtype.cmlxDtype) + ptr.baseAddress!, shape, shape.count.int32, Int32.dtype.cmlxDtype) }) } @@ -323,15 +325,15 @@ extension MLXArray { /// /// ### See Also /// - - public convenience init(int64 value: [Int], _ shape: [Int]? = nil) { + public convenience init(int64 value: [Int], _ shape: (some Collection)? = [Int]?.none) { shapePrecondition(shape: shape, count: value.count) self.init( value .withUnsafeBufferPointer { ptr in - let shape = shape ?? [value.count] + let shape = shape?.asInt32 ?? [value.count.int32] return mlx_array_new_data( - ptr.baseAddress!, shape.asInt32, shape.count.int32, Int.dtype.cmlxDtype) + ptr.baseAddress!, shape, shape.count.int32, Int.dtype.cmlxDtype) }) } @@ -346,14 +348,16 @@ extension MLXArray { /// /// ### See Also /// - - public convenience init(converting value: [Double], _ shape: [Int]? = nil) { + public convenience init( + converting value: [Double], _ shape: (some Collection)? = [Int]?.none + ) { shapePrecondition(shape: shape, count: value.count) let floats = value.map { Float($0) } self.init( floats.withUnsafeBufferPointer { ptr in - let shape = shape ?? [floats.count] + let shape = shape?.asInt32 ?? [floats.count.int32] return mlx_array_new_data( - ptr.baseAddress!, shape.asInt32, shape.count.int32, Float.dtype.cmlxDtype) + ptr.baseAddress!, shape, shape.count.int32, Float.dtype.cmlxDtype) }) } @@ -362,7 +366,7 @@ extension MLXArray { *, unavailable, renamed: "MLXArray(converting:shape:)", message: "Use MLXArray(converting: [1.0, 2.0, ...]) instead" ) - public convenience init(_ value: [Double], _ shape: [Int]? = nil) { + public convenience init(_ value: [Double], _ shape: (some Collection)? = [Int]?.none) { fatalError("unavailable") } @@ -379,7 +383,9 @@ extension MLXArray { /// /// ### See Also /// - - public convenience init(_ sequence: S, _ shape: [Int]? = nil) + public convenience init( + _ sequence: S, _ shape: (some Collection)? = [Int]?.none + ) where S.Element: HasDType { let value = Array(sequence) if S.Element.self == Int.self { @@ -402,14 +408,16 @@ extension MLXArray { /// /// ### See Also /// - - public convenience init(int64 sequence: any Sequence, _ shape: [Int]? = nil) { + public convenience init( + int64 sequence: some Sequence, _ shape: (some Collection)? = [Int]?.none + ) { let value = Array(sequence) shapePrecondition(shape: shape, count: value.count) self.init( value.withUnsafeBufferPointer { ptr in - let shape = shape ?? [value.count] + let shape = shape?.asInt32 ?? [value.count.int32] return mlx_array_new_data( - ptr.baseAddress!, shape.asInt32, shape.count.int32, Int.dtype.cmlxDtype) + ptr.baseAddress!, shape, shape.count.int32, Int.dtype.cmlxDtype) }) } @@ -425,12 +433,14 @@ extension MLXArray { /// /// ### See Also /// - - public convenience init(_ ptr: UnsafeBufferPointer, _ shape: [Int]? = nil) { + public convenience init( + _ ptr: UnsafeBufferPointer, _ shape: (some Collection)? = [Int]?.none + ) { shapePrecondition(shape: shape, count: ptr.count) - let shape = shape ?? [ptr.count] + let shape = shape?.asInt32 ?? [ptr.count.int32] self.init( mlx_array_new_data( - ptr.baseAddress!, shape.asInt32, shape.count.int32, T.dtype.cmlxDtype)) + ptr.baseAddress!, shape, shape.count.int32, T.dtype.cmlxDtype)) } /// Initializer allowing creation of `MLXArray` from a `UnsafeRawBufferPointer` filled @@ -444,7 +454,7 @@ extension MLXArray { /// ### See Also /// - public convenience init( - _ ptr: UnsafeRawBufferPointer, _ shape: [Int]? = nil, type: T.Type + _ ptr: UnsafeRawBufferPointer, _ shape: (some Collection)? = [Int]?.none, type: T.Type ) { let buffer = ptr.assumingMemoryBound(to: type) self.init(buffer, shape) @@ -460,14 +470,16 @@ extension MLXArray { /// /// ### See Also /// - - public convenience init(_ data: Data, _ shape: [Int]? = nil, type: T.Type) { + public convenience init( + _ data: Data, _ shape: (some Collection)? = [Int]?.none, type: T.Type + ) { self.init( data.withUnsafeBytes { ptr in let buffer = ptr.assumingMemoryBound(to: type) shapePrecondition(shape: shape, count: buffer.count) - let shape = shape ?? [buffer.count] + let shape = shape?.asInt32 ?? [buffer.count.int32] return mlx_array_new_data( - ptr.baseAddress!, shape.asInt32, shape.count.int32, T.dtype.cmlxDtype) + ptr.baseAddress!, shape, shape.count.int32, T.dtype.cmlxDtype) }) } @@ -475,14 +487,16 @@ extension MLXArray { /// an optional shape and an explicit DType. /// ### See Also /// - - public convenience init(_ data: Data, _ shape: [Int]? = nil, dtype: DType) { + public convenience init( + _ data: Data, _ shape: (some Collection)? = [Int]?.none, dtype: DType + ) { self.init( data.withUnsafeBytes { ptr in shapePrecondition(shape: shape, byteCount: data.count, type: dtype) precondition(data.count % dtype.size == 0) - let shape = shape ?? [data.count / dtype.size] + let shape = shape?.asInt32 ?? [Int32(data.count / dtype.size)] return mlx_array_new_data( - ptr.baseAddress!, shape.asInt32, shape.count.int32, dtype.cmlxDtype) + ptr.baseAddress!, shape, shape.count.int32, dtype.cmlxDtype) }) } diff --git a/Source/MLX/MLXArray+Ops.swift b/Source/MLX/MLXArray+Ops.swift index f5bbaece..f1aa12d4 100644 --- a/Source/MLX/MLXArray+Ops.swift +++ b/Source/MLX/MLXArray+Ops.swift @@ -1209,7 +1209,9 @@ extension MLXArray { /// - ``all(axis:keepDims:stream:)`` /// - ``all(keepDims:stream:)`` /// - ``all(_:axes:keepDims:stream:)`` - public func all(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) + public func all( + axes: some Collection, keepDims: Bool = false, stream: StreamOrDevice = .default + ) -> MLXArray { var result = mlx_array_new() @@ -1321,7 +1323,9 @@ extension MLXArray { /// - ``any(axis:keepDims:stream:)`` /// - ``any(keepDims:stream:)`` /// - ``any(_:axes:keepDims:stream:)`` - public func any(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) + public func any( + axes: some Collection, keepDims: Bool = false, stream: StreamOrDevice = .default + ) -> MLXArray { var result = mlx_array_new() @@ -1754,7 +1758,7 @@ extension MLXArray { /// ### See Also /// - /// - ``expandedDimensions(axis:stream:)`` - public func expandedDimensions(axes: [Int], stream: StreamOrDevice = .default) + public func expandedDimensions(axes: some Collection, stream: StreamOrDevice = .default) -> MLXArray { var result = mlx_array_new() @@ -1925,7 +1929,9 @@ extension MLXArray { /// - ``logSumExp(axis:keepDims:stream:)`` /// - ``logSumExp(keepDims:stream:)`` /// - ``logSumExp(_:axes:keepDims:stream:)`` - public func logSumExp(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) + public func logSumExp( + axes: some Collection, keepDims: Bool = false, stream: StreamOrDevice = .default + ) -> MLXArray { var result = mlx_array_new() @@ -2035,7 +2041,9 @@ extension MLXArray { /// - ``max(axis:keepDims:stream:)`` /// - ``max(keepDims:stream:)`` /// - ``max(_:axes:keepDims:stream:)`` - public func max(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) + public func max( + axes: some Collection, keepDims: Bool = false, stream: StreamOrDevice = .default + ) -> MLXArray { var result = mlx_array_new() @@ -2113,7 +2121,9 @@ extension MLXArray { /// - ``mean(axis:keepDims:stream:)`` /// - ``mean(keepDims:stream:)`` /// - ``mean(_:axes:keepDims:stream:)`` - public func mean(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) + public func mean( + axes: some Collection, keepDims: Bool = false, stream: StreamOrDevice = .default + ) -> MLXArray { var result = mlx_array_new() @@ -2191,7 +2201,9 @@ extension MLXArray { /// - ``min(axis:keepDims:stream:)`` /// - ``min(keepDims:stream:)`` /// - ``min(_:axes:keepDims:stream:)`` - public func min(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) + public func min( + axes: some Collection, keepDims: Bool = false, stream: StreamOrDevice = .default + ) -> MLXArray { var result = mlx_array_new() @@ -2333,7 +2345,9 @@ extension MLXArray { /// - ``product(axis:keepDims:stream:)`` /// - ``product(keepDims:stream:)`` /// - ``product(_:axes:keepDims:stream:)`` - public func product(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) + public func product( + axes: some Collection, keepDims: Bool = false, stream: StreamOrDevice = .default + ) -> MLXArray { var result = mlx_array_new() @@ -2415,7 +2429,9 @@ extension MLXArray { /// - /// - ``reshaped(_:stream:)`` /// - ``reshaped(_:_:stream:)-96lgr`` - public func reshaped(_ newShape: [Int], stream: StreamOrDevice = .default) -> MLXArray { + public func reshaped(_ newShape: some Collection, stream: StreamOrDevice = .default) + -> MLXArray + { var result = mlx_array_new() mlx_reshape(&result, ctx, newShape.asInt32, newShape.count, stream.ctx) return MLXArray(result) @@ -2550,7 +2566,9 @@ extension MLXArray { /// - /// - ``split(parts:axis:stream:)`` /// - ``split(_:indices:axis:stream:)`` - public func split(indices: [Int], axis: Int = 0, stream: StreamOrDevice = .default) + public func split( + indices: some Collection, axis: Int = 0, stream: StreamOrDevice = .default + ) -> [MLXArray] { var vec = mlx_vector_array_new() @@ -2591,7 +2609,8 @@ extension MLXArray { /// - ``squeezed(axis:stream:)`` /// - ``squeezed(stream:)`` /// - ``squeezed(_:axes:stream:)`` - public func squeezed(axes: [Int], stream: StreamOrDevice = .default) -> MLXArray { + public func squeezed(axes: some Collection, stream: StreamOrDevice = .default) -> MLXArray + { var result = mlx_array_new() mlx_squeeze_axes(&result, ctx, axes.asInt32, axes.count, stream.ctx) return MLXArray(result) @@ -2638,7 +2657,9 @@ extension MLXArray { /// - ``sum(axis:keepDims:stream:)`` /// - ``sum(keepDims:stream:)`` /// - ``sum(_:axes:keepDims:stream:)`` - public func sum(axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default) + public func sum( + axes: some Collection, keepDims: Bool = false, stream: StreamOrDevice = .default + ) -> MLXArray { var result = mlx_array_new() @@ -2763,7 +2784,9 @@ extension MLXArray { /// - ``transposed(axis:stream:)`` /// - ``transposed(stream:)`` /// - ``transposed(_:axes:stream:)`` - public func transposed(axes: [Int], stream: StreamOrDevice = .default) -> MLXArray { + public func transposed(axes: some Collection, stream: StreamOrDevice = .default) + -> MLXArray + { var result = mlx_array_new() mlx_transpose_axes(&result, ctx, axes.asInt32, axes.count, stream.ctx) return MLXArray(result) @@ -2825,7 +2848,8 @@ extension MLXArray { /// - ``variance(keepDims:ddof:stream:)`` /// - ``variance(_:axes:keepDims:ddof:stream:)`` public func variance( - axes: [Int], keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default + axes: some Collection, keepDims: Bool = false, ddof: Int = 0, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_var_axes(&result, ctx, axes.asInt32, axes.count, keepDims, ddof.int32, stream.ctx) diff --git a/Source/MLX/MLXFastKernel.swift b/Source/MLX/MLXFastKernel.swift index bed52d15..a58be428 100644 --- a/Source/MLX/MLXFastKernel.swift +++ b/Source/MLX/MLXFastKernel.swift @@ -45,12 +45,12 @@ extension MLXFast { public let outputNames: [String] init( - name: String, inputNames: [String], outputNames: [String], + name: String, inputNames: some Sequence, outputNames: some Sequence, source: String, header: String = "", ensureRowContiguous: Bool = true, atomicOutputs: Bool = false ) { - self.outputNames = outputNames + self.outputNames = Array(outputNames) let input_names = mlx_vector_string_new() defer { mlx_vector_string_free(input_names) } @@ -60,7 +60,7 @@ extension MLXFast { let output_names = mlx_vector_string_new() defer { mlx_vector_string_free(output_names) } - for name in outputNames { + for name in self.outputNames { mlx_vector_string_append_value(output_names, name) } @@ -94,12 +94,12 @@ extension MLXFast { /// - stream: stream to run on /// - Returns: array of `MLXArray` public func callAsFunction( - _ inputs: [ScalarOrArray], - template: [(String, KernelTemplateArg)]? = nil, + _ inputs: [any ScalarOrArray], + template: [(String, any KernelTemplateArg)]? = nil, grid: (Int, Int, Int), threadGroup: (Int, Int, Int), - outputShapes: [[Int]], - outputDTypes: [DType], + outputShapes: some Sequence<[Int]>, + outputDTypes: some Sequence, initValue: Float? = nil, verbose: Bool = false, stream: StreamOrDevice = .default @@ -171,7 +171,7 @@ extension MLXFast { /// e.g. `device atomic` /// - Returns: an ``MLXFastKernel`` -- see that for information on how to call it public static func metalKernel( - name: String, inputNames: [String], outputNames: [String], + name: String, inputNames: some Sequence, outputNames: some Sequence, source: String, header: String = "", ensureRowContiguous: Bool = true, atomicOutputs: Bool = false diff --git a/Source/MLX/Nested.swift b/Source/MLX/Nested.swift index 8bd35df6..26f7bed7 100644 --- a/Source/MLX/Nested.swift +++ b/Source/MLX/Nested.swift @@ -465,7 +465,9 @@ public indirect enum NestedItem: IndentedDescription { /// - ``NestedDictionary/flattened(prefix:)`` /// - ``NestedDictionary/unflattened(_:)-4p8bn`` /// - ``NestedDictionary/unflattened(_:)-7xuiv`` - public static func unflattened(_ tree: [(Key, Element)]) -> NestedItem + public static func unflattened(_ tree: some Collection<(Key, Element)>) -> NestedItem< + Key, Element + > where Key == String { if tree.isEmpty { return .dictionary([:]) @@ -489,11 +491,13 @@ public indirect enum NestedItem: IndentedDescription { } } - private static func unflattenedRecurse(_ tree: [(String, Element)]) -> NestedItem< - String, Element - > { - if tree.count == 1 && tree[0].0 == "" { - return .value(tree[0].1) + private static func unflattenedRecurse(_ tree: some Collection<(String, Element)>) + -> NestedItem< + String, Element + > + { + if tree.count == 1, let first = tree.first, first.0 == "" { + return .value(first.1) } var children = [String: [(String, Element)]]() @@ -510,7 +514,7 @@ public indirect enum NestedItem: IndentedDescription { children[String(current), default: []].append((String(next), value)) } - switch UnflattenKind.detect(key: tree[0].0) { + switch UnflattenKind.detect(key: tree.first!.0) { case .list: if children.isEmpty { return .array([]) @@ -565,7 +569,9 @@ public indirect enum NestedItem: IndentedDescription { } } - func replacingValues(with values: [Element], index: Int) -> (Int, NestedItem) { + func replacingValues>(with values: Values, index: Int) -> ( + Int, NestedItem + ) where Values.Index == Int { switch self { case .none: return (index, .none) @@ -951,7 +957,9 @@ public struct NestedDictionary: CustomStringConvertible /// ### See Also /// - ``flattened(prefix:)`` /// - ``unflattened(_:)-7xuiv`` - static public func unflattened(_ flat: [(Key, Element)]) -> NestedDictionary + static public func unflattened(_ flat: some Collection<(Key, Element)>) -> NestedDictionary< + String, Element + > where Key == String { switch NestedItem.unflattened(flat) { case .dictionary(let values): @@ -987,7 +995,9 @@ public struct NestedDictionary: CustomStringConvertible /// /// ### See Also /// - ``flattenedValues()`` - public func replacingValues(with values: [Element]) -> NestedDictionary { + public func replacingValues>(with values: Values) + -> NestedDictionary where Values.Index == Int + { switch asItem().replacingValues(with: values, index: 0) { case (_, .dictionary(let values)): return NestedDictionary(values: values) diff --git a/Source/MLX/Ops+Array.swift b/Source/MLX/Ops+Array.swift index c53c331d..1332ba49 100644 --- a/Source/MLX/Ops+Array.swift +++ b/Source/MLX/Ops+Array.swift @@ -49,7 +49,8 @@ public func abs(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArra /// - ``all(_:keepDims:stream:)`` /// - ``MLXArray/all(axes:keepDims:stream:)`` public func all( - _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, axes: some Collection, keepDims: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_all_axes(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) @@ -171,7 +172,8 @@ public func allClose( /// - ``any(_:keepDims:stream:)`` /// - ``MLXArray/any(axes:keepDims:stream:)`` public func any( - _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, axes: some Collection, keepDims: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_any_axes(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) @@ -845,7 +847,8 @@ public func log1p(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXAr /// - ``logSumExp(_:keepDims:stream:)`` /// - ``MLXArray/logSumExp(axes:keepDims:stream:)`` public func logSumExp( - _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, axes: some Collection, keepDims: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_logsumexp_axes(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) @@ -968,7 +971,8 @@ public func matmul(_ a: MLXArray, _ b: MLXArray, stream: StreamOrDevice = .defau /// - ``max(_:keepDims:stream:)`` /// - ``MLXArray/max(axes:keepDims:stream:)`` public func max( - _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, axes: some Collection, keepDims: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_max_axes(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) @@ -1051,7 +1055,8 @@ public func max(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevic /// - ``mean(_:keepDims:stream:)`` /// - ``MLXArray/mean(axes:keepDims:stream:)`` public func mean( - _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, axes: some Collection, keepDims: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_mean_axes(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) @@ -1134,7 +1139,8 @@ public func mean(_ array: MLXArray, keepDims: Bool = false, stream: StreamOrDevi /// - ``min(_:keepDims:stream:)`` /// - ``MLXArray/min(axes:keepDims:stream:)`` public func min( - _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, axes: some Collection, keepDims: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_min_axes(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) @@ -1294,7 +1300,8 @@ public func pow(_ array: T, _ other: MLXArray, stream: StreamO /// - ``product(_:keepDims:stream:)`` /// - ``MLXArray/product(axes:keepDims:stream:)`` public func product( - _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, axes: some Collection, keepDims: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_prod_axes(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) @@ -1379,7 +1386,9 @@ public func reciprocal(_ array: MLXArray, stream: StreamOrDevice = .default) -> /// - /// - ``MLXArray/reshaped(_:stream:)-19x5z`` /// - ``reshaped(_:_:stream:)-96lgr`` -public func reshaped(_ array: MLXArray, _ newShape: [Int], stream: StreamOrDevice = .default) +public func reshaped( + _ array: MLXArray, _ newShape: some Collection, stream: StreamOrDevice = .default +) -> MLXArray { var result = mlx_array_new() @@ -1546,7 +1555,8 @@ public func split(_ array: MLXArray, axis: Int = 0, stream: StreamOrDevice = .de /// - ``split(_:parts:axis:stream:)`` /// - ``MLXArray/split(indices:axis:stream:)`` public func split( - _ array: MLXArray, indices: [Int], axis: Int = 0, stream: StreamOrDevice = .default + _ array: MLXArray, indices: some Collection, axis: Int = 0, + stream: StreamOrDevice = .default ) -> [MLXArray] { var vec = mlx_vector_array_new() mlx_split_sections(&vec, array.ctx, indices.asInt32, indices.count, axis.int32, stream.ctx) @@ -1587,8 +1597,9 @@ public func square(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXA /// - ``squeezed(_:axis:stream:)`` /// - ``squeezed(_:stream:)`` /// - ``MLXArray/squeezed(axes:stream:)`` -public func squeezed(_ array: MLXArray, axes: [Int], stream: StreamOrDevice = .default) -> MLXArray -{ +public func squeezed( + _ array: MLXArray, axes: some Collection, stream: StreamOrDevice = .default +) -> MLXArray { var result = mlx_array_new() mlx_squeeze_axes(&result, array.ctx, axes.asInt32, axes.count, stream.ctx) return MLXArray(result) @@ -1638,7 +1649,8 @@ public func squeezed(_ array: MLXArray, stream: StreamOrDevice = .default) -> ML /// - ``sum(_:keepDims:stream:)`` /// - ``MLXArray/sum(axes:keepDims:stream:)`` public func sum( - _ array: MLXArray, axes: [Int], keepDims: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, axes: some Collection, keepDims: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_sum_axes(&result, array.ctx, axes.asInt32, axes.count, keepDims, stream.ctx) @@ -1770,7 +1782,9 @@ public func take(_ array: MLXArray, _ indices: MLXArray, stream: StreamOrDevice /// - ``transposed(_:axis:stream:)`` /// - ``transposed(_:stream:)`` /// - ``MLXArray/transposed(axes:stream:)`` -public func transposed(_ array: MLXArray, axes: [Int], stream: StreamOrDevice = .default) +public func transposed( + _ array: MLXArray, axes: some Collection, stream: StreamOrDevice = .default +) -> MLXArray { var result = mlx_array_new() @@ -1838,7 +1852,7 @@ public func T(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArray /// - ``variance(_:keepDims:ddof:stream:)`` /// - ``MLXArray/variance(axes:keepDims:ddof:stream:)`` public func variance( - _ array: MLXArray, axes: [Int], keepDims: Bool = false, ddof: Int = 0, + _ array: MLXArray, axes: some Collection, keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() diff --git a/Source/MLX/Ops.swift b/Source/MLX/Ops.swift index 3191edf1..1cc41550 100644 --- a/Source/MLX/Ops.swift +++ b/Source/MLX/Ops.swift @@ -6,7 +6,7 @@ import Foundation // MARK: - Internal Ops /// Broadcast a vector of arrays against one another. -func broadcast(arrays: [MLXArray], stream: StreamOrDevice = .default) -> [MLXArray] { +func broadcast(arrays: some Collection, stream: StreamOrDevice = .default) -> [MLXArray] { let vector_array = new_mlx_vector_array(arrays) defer { mlx_vector_array_free(vector_array) } @@ -359,10 +359,11 @@ public func argSort(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLX /// ### See Also /// - public func asStrided( - _ array: MLXArray, _ shape: [Int]? = nil, strides: [Int]? = nil, offset: Int = 0, + _ array: MLXArray, _ shape: (some Collection)? = [Int]?.none, + strides: (some Collection)? = [Int]?.none, offset: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { - let shape = shape ?? array.shape + let shape = shape.map { Array($0) } ?? array.shape let resolvedStrides: [Int64] if let strides { @@ -439,7 +440,9 @@ public func blockMaskedMM( /// /// ### See Also /// - -public func broadcast(_ array: MLXArray, to shape: [Int], stream: StreamOrDevice = .default) +public func broadcast( + _ array: MLXArray, to shape: some Collection, stream: StreamOrDevice = .default +) -> MLXArray { var result = mlx_array_new() @@ -536,7 +539,9 @@ public func clip(_ array: MLXArray, max: A, stream: StreamOrDe /// /// ### See Also /// - -public func concatenated(_ arrays: [MLXArray], axis: Int = 0, stream: StreamOrDevice = .default) +public func concatenated( + _ arrays: some Collection, axis: Int = 0, stream: StreamOrDevice = .default +) -> MLXArray { let vector_array = new_mlx_vector_array(arrays) @@ -682,13 +687,13 @@ public func conv3d( /// - Parameters: /// - array: Input array of shape `(N, ..., C_in)` /// - weight: Weight array of shape `(C_out, ..., C_in)` -/// - strides: `Int` or `[Int]` with kernel strides. All dimensions get the +/// - strides: `Int` or `some Collection` with kernel strides. All dimensions get the /// same stride if only one number is specified. -/// - padding: `Int` or `[Int]` with input padding. All dimensions get the +/// - padding: `Int` or `some Collection` with input padding. All dimensions get the /// same padding if only one number is specified. -/// - kernelDilation: `Int` or `[Int]` with kernel dilation. All dimensions get the +/// - kernelDilation: `Int` or `some Collection` with kernel dilation. All dimensions get the /// same dilation if only one number is specified. -/// - inputDilation: `Int` or `[Int]` with input dilation. All dimensions get the +/// - inputDilation: `Int` or `some Collection` with input dilation. All dimensions get the /// same dilation if only one number is specified. /// - groups: input feature groups /// - flip: Flip the order in which the spatial dimensions of the weights are processed. @@ -729,12 +734,12 @@ public func convGeneral( /// - Parameters: /// - array: Input array of shape `(N, ..., C_in)` /// - weight: Weight array of shape `(C_out, ..., C_in)` -/// - strides: `Int` or `[Int]` with kernel strides. All dimensions get the +/// - strides: `Int` or `some Collection` with kernel strides. All dimensions get the /// same stride if only one number is specified. /// - padding: pair of padding values to apply to all dimensions -/// - kernelDilation: `Int` or `[Int]` with kernel dilation. All dimensions get the +/// - kernelDilation: `Int` or `some Collection` with kernel dilation. All dimensions get the /// same dilation if only one number is specified. -/// - inputDilation: `Int` or `[Int]` with input dilation. All dimensions get the +/// - inputDilation: `Int` or `some Collection` with input dilation. All dimensions get the /// same dilation if only one number is specified. /// - groups: input feature groups /// - flip: Flip the order in which the spatial dimensions of the weights are processed. @@ -1142,7 +1147,9 @@ public func einsum(_ subscripts: String, _ operands: MLXArray..., stream: Stream /// - subscripts: Einstein summation convention equation /// - operands: input arrays /// - stream: stream or device to evaluate on -public func einsum(_ subscripts: String, operands: [MLXArray], stream: StreamOrDevice = .default) +public func einsum( + _ subscripts: String, operands: some Collection, stream: StreamOrDevice = .default +) -> MLXArray { let operands = new_mlx_vector_array(operands) @@ -1230,7 +1237,9 @@ public func erfInverse(_ array: MLXArray, stream: StreamOrDevice = .default) -> /// ### See Also /// - /// - ``expandedDimensions(_:axis:stream:)`` -public func expandedDimensions(_ array: MLXArray, axes: [Int], stream: StreamOrDevice = .default) +public func expandedDimensions( + _ array: MLXArray, axes: some Collection, stream: StreamOrDevice = .default +) -> MLXArray { var result = mlx_array_new() @@ -1748,7 +1757,7 @@ public enum MeshGridIndexing: String, Sendable { /// - indexing: indexing mode /// - stream: stream or device to evaluate on public func meshGrid( - _ arrays: [MLXArray], sparse: Bool = false, indexing: MeshGridIndexing = .xy, + _ arrays: some Collection, sparse: Bool = false, indexing: MeshGridIndexing = .xy, stream: StreamOrDevice = .default ) -> [MLXArray] { let mlxArrays = new_mlx_vector_array(arrays) @@ -1985,7 +1994,8 @@ public func padded( /// - /// - ``padded(_:width:mode:value:stream:)`` public func padded( - _ array: MLXArray, widths: [IntOrPair], mode: PadMode = .constant, value: MLXArray? = nil, + _ array: MLXArray, widths: some Collection, mode: PadMode = .constant, + value: MLXArray? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let ndim = array.ndim @@ -2252,7 +2262,10 @@ public func roll(_ a: MLXArray, shift: Int, axis: Int, stream: StreamOrDevice = /// /// ### See Also /// - -public func roll(_ a: MLXArray, shift: Int, axes: [Int]? = nil, stream: StreamOrDevice = .default) +public func roll( + _ a: MLXArray, shift: Int, axes: (some Collection)? = [Int]?.none, + stream: StreamOrDevice = .default +) -> MLXArray { var result = mlx_array_new() @@ -2313,7 +2326,8 @@ public func sinh(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXArr @available(*, deprecated, renamed: "softmax(_:axes:precise:stream:)") @_documentation(visibility: internal) public func softMax( - _ array: MLXArray, axes: [Int], precise: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, axes: some Collection, precise: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { softmax(array, axes: axes, precise: precise, stream: stream) } @@ -2335,7 +2349,8 @@ public func softMax( /// - ``softmax(_:axis:precise:stream:)`` /// - ``softmax(_:precise:stream:)`` public func softmax( - _ array: MLXArray, axes: [Int], precise: Bool = false, stream: StreamOrDevice = .default + _ array: MLXArray, axes: some Collection, precise: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_softmax_axes(&result, array.ctx, axes.asInt32, axes.count, precise, stream.ctx) @@ -2452,7 +2467,7 @@ public func sorted(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLXA /// - ``std(_:axis:keepDims:ddof:stream:)`` /// - ``std(_:keepDims:ddof:stream:)`` public func std( - _ array: MLXArray, axes: [Int], keepDims: Bool = false, ddof: Int = 0, + _ array: MLXArray, axes: some Collection, keepDims: Bool = false, ddof: Int = 0, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -2506,7 +2521,9 @@ public func std( /// /// ### See Also /// - -public func stacked(_ arrays: [MLXArray], axis: Int = 0, stream: StreamOrDevice = .default) +public func stacked( + _ arrays: some Collection, axis: Int = 0, stream: StreamOrDevice = .default +) -> MLXArray { let vector_array = new_mlx_vector_array(arrays) @@ -2681,7 +2698,8 @@ public func tensordot( /// - /// - ``tensordot(_:_:axes:stream:)`` public func tensordot( - _ a: MLXArray, _ b: MLXArray, axes: ([Int], [Int]), stream: StreamOrDevice = .default + _ a: MLXArray, _ b: MLXArray, axes: (some Collection, some Collection), + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_tensordot( @@ -2702,7 +2720,9 @@ public func tensordot( /// ### See Also /// - /// - ``tiled(_:repetitions:stream:)-eouf`` -public func tiled(_ array: MLXArray, repetitions: [Int], stream: StreamOrDevice = .default) +public func tiled( + _ array: MLXArray, repetitions: some Collection, stream: StreamOrDevice = .default +) -> MLXArray { var result = mlx_array_new() @@ -2913,7 +2933,7 @@ public func flatten( /// - shape: shape to unflatten into /// - stream: stream or device to evaluate on public func unflatten( - _ a: MLXArray, axis: Int, shape: [Int], stream: StreamOrDevice = .default + _ a: MLXArray, axis: Int, shape: some Collection, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_unflatten(&result, a.ctx, axis.int32, shape.map { Int32($0) }, shape.count, stream.ctx) diff --git a/Source/MLX/ParameterTypes.swift b/Source/MLX/ParameterTypes.swift index 08d0080a..e8a7317b 100644 --- a/Source/MLX/ParameterTypes.swift +++ b/Source/MLX/ParameterTypes.swift @@ -27,9 +27,9 @@ public struct IntOrPair: ExpressibleByIntegerLiteral, ExpressibleByArrayLiteral, self.values = (elements[0], elements[1]) } - public init(_ values: [Int]) { + public init(_ values: some Collection) { precondition(values.count == 2) - self.values = (values[0], values[1]) + self.values = (values.first!, values[values.index(after: values.startIndex)]) } public init(_ values: (Int, Int)) { @@ -67,9 +67,12 @@ public struct IntOrTriple: ExpressibleByIntegerLiteral, ExpressibleByArrayLitera self.values = (elements[0], elements[1], elements[2]) } - public init(_ values: [Int]) { + public init(_ values: some Collection) { precondition(values.count == 3) - self.values = (values[0], values[1], values[2]) + self.values = ( + values.first!, values[values.index(after: values.startIndex)], + values[values.index(values.startIndex, offsetBy: 2)] + ) } public init(_ values: (Int, Int, Int)) { diff --git a/Source/MLX/Random.swift b/Source/MLX/Random.swift index 786f206b..a6294167 100644 --- a/Source/MLX/Random.swift +++ b/Source/MLX/Random.swift @@ -148,8 +148,8 @@ public enum MLXRandom { /// let array = MLXRandom.uniform(0.5 ..< 1, [50], key: key) /// ``` public static func uniform( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, - key: RandomStateOrKey? = nil, + _ range: Range, _ shape: some Collection = [], type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let lb = MLXArray(range.lowerBound) @@ -173,8 +173,9 @@ public enum MLXRandom { /// let array = MLXRandom.uniform(0.5 ..< 1, [50], key: key) /// ``` public static func uniform( - _ range: Range = 0 ..< 1, _ shape: [Int] = [], type: T.Type = Float.self, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + _ range: Range = 0 ..< 1, _ shape: some Collection = [], + type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let lb = MLXArray(range.lowerBound) let ub = MLXArray(range.upperBound) @@ -203,11 +204,12 @@ public enum MLXRandom { /// let value = MLXRandom.uniform(low: [0, 10], high: [10, 100], key: key) /// ``` public static func uniform( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type = Float.self, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let (low, high) = toArrays(low, high) - let shape = shape ?? low.shape + let shape = shape.map { Array($0) } ?? low.shape let key = resolve(key: key) var result = mlx_array_new() @@ -233,11 +235,12 @@ public enum MLXRandom { /// let value = MLXRandom.uniform(low: [0, 10], high: [10, 100], key: key) /// ``` public static func uniform( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, dtype: DType = .float32, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, dtype: DType, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { let (low, high) = toArrays(low, high) - let shape = shape ?? low.shape + let shape = shape.map { Array($0) } ?? low.shape let key = resolve(key: key) var result = mlx_array_new() @@ -271,8 +274,9 @@ public enum MLXRandom { /// - scale: standard deviation of the distribution /// - key: PRNG key public static func normal( - _ shape: [Int] = [], type: T.Type = Float.self, loc: Float = 0, scale: Float = 1, - key: RandomStateOrKey? = nil, + _ shape: some Collection = [], type: T.Type = Float.self, loc: Float = 0, + scale: Float = 1, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let key = resolve(key: key) @@ -306,8 +310,8 @@ public enum MLXRandom { /// - scale: standard deviation of the distribution /// - key: PRNG key public static func normal( - _ shape: [Int] = [], dtype: DType = .float32, loc: Float = 0, scale: Float = 1, - key: RandomStateOrKey? = nil, + _ shape: some Collection = [], dtype: DType, loc: Float = 0, scale: Float = 1, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { let key = resolve(key: key) @@ -336,8 +340,8 @@ public enum MLXRandom { /// - dtype: DType of the result /// - key: PRNG key public static func multivariateNormal( - mean: MLXArray, covariance: MLXArray, shape: [Int] = [], dtype: DType = .float32, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + mean: MLXArray, covariance: MLXArray, shape: some Collection = [], dtype: DType, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { let key = resolve(key: key) var result = mlx_array_new() @@ -365,7 +369,8 @@ public enum MLXRandom { /// let array = MLXRandom.randInt(Int32(0) ..< 100, [50], key: key) /// ``` public static func randInt( - _ range: Range, _ shape: [Int] = [], key: RandomStateOrKey? = nil, + _ range: Range, _ shape: some Collection = [], + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryInteger { let lb = MLXArray(range.lowerBound) @@ -394,12 +399,13 @@ public enum MLXRandom { /// let array = MLXRandom.randInt(low: [0, 10], high: [10, 100], key: key) /// ``` public static func randInt( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, - key: RandomStateOrKey? = nil, + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { let (low, high) = toArrays(low, high) - let shape = shape ?? low.shape + let shape = shape.map { Array($0) } ?? low.shape let key = resolve(key: key) var result = mlx_array_new() @@ -425,11 +431,13 @@ public enum MLXRandom { /// let array = MLXRandom.randInt(low: [0, 10], high: [10, 100], type: Int8.self, key: key) /// ``` public static func randInt( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, + type: T.Type, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryInteger { let (low, high) = toArrays(low, high) - let shape = shape ?? low.shape + let shape = shape.map { Array($0) } ?? low.shape let key = resolve(key: key) var result = mlx_array_new() @@ -455,7 +463,8 @@ public enum MLXRandom { /// let array = MLXRandom.bernoulli([50, 2], key: key) /// ``` public static func bernoulli( - _ shape: [Int] = [], key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + _ shape: some Collection = [], key: (some RandomStateOrKey)? = MLXArray?.none, + stream: StreamOrDevice = .default ) -> MLXArray { @@ -486,11 +495,12 @@ public enum MLXRandom { /// let array = MLXRandom.bernoulli(MLXArray(convert: [0.1, 0.5, 0.8]), key: key) /// ``` public static func bernoulli( - _ p: ScalarOrArray, _ shape: [Int]? = nil, key: RandomStateOrKey? = nil, + _ p: some ScalarOrArray, _ shape: (some Collection)? = [Int]?.none, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { let p = p.asMLXArray(dtype: .float32) - let shape = shape ?? p.shape + let shape = shape.map { Array($0) } ?? p.shape let key = resolve(key: key) var result = mlx_array_new() mlx_random_bernoulli(&result, p.ctx, shape.asInt32, shape.count, key.ctx, stream.ctx) @@ -517,8 +527,8 @@ public enum MLXRandom { /// ### See also /// - [JAX Documentation](https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal) public static func truncatedNormal( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, - key: RandomStateOrKey? = nil, + _ range: Range, _ shape: some Collection = [], type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let lb = MLXArray(range.lowerBound) @@ -542,8 +552,8 @@ public enum MLXRandom { /// let array = MLXRandom.truncatedNormal(0.5 ..< 1, [50], key: key) /// ``` public static func truncatedNormal( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, - key: RandomStateOrKey? = nil, + _ range: Range, _ shape: some Collection = [], type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let lb = MLXArray(range.lowerBound) @@ -572,11 +582,13 @@ public enum MLXRandom { /// let value = MLXRandom.truncatedNormal([0, 10], [10, 100], key: key) /// ``` public static func truncatedNormal( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type = Float.self, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, + type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let (low, high) = toArrays(low, high) - let shape = shape ?? low.shape + let shape = shape.map { Array($0) } ?? low.shape let key = resolve(key: key) var result = mlx_array_new() @@ -601,11 +613,13 @@ public enum MLXRandom { /// let value = MLXRandom.truncatedNormal([0, 10], [10, 100], key: key) /// ``` public static func truncatedNormal( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, dtype: DType = .float32, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, + dtype: DType, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { let (low, high) = toArrays(low, high) - let shape = shape ?? low.shape + let shape = shape.map { Array($0) } ?? low.shape let key = resolve(key: key) var result = mlx_array_new() @@ -632,7 +646,8 @@ public enum MLXRandom { /// let array = MLXRandom.gumbel([10, 5], key: key) /// ``` public static func gumbel( - _ shape: [Int] = [], type: T.Type = Float.self, key: RandomStateOrKey? = nil, + _ shape: some Collection = [], type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { let key = resolve(key: key) @@ -659,7 +674,8 @@ public enum MLXRandom { /// let array = MLXRandom.gumbel([10, 5], key: key) /// ``` public static func gumbel( - _ shape: [Int] = [], dtype: DType = .float32, key: RandomStateOrKey? = nil, + _ shape: some Collection = [], dtype: DType, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { let key = resolve(key: key) @@ -689,7 +705,8 @@ public enum MLXRandom { /// - Parameters: /// - logits: The *unnormalized* categorical distribution(s). public static func categorical( - _ logits: MLXArray, axis: Int = -1, shape: [Int]? = nil, key: RandomStateOrKey? = nil, + _ logits: MLXArray, axis: Int = -1, shape: (some Collection)? = [Int]?.none, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { let key = resolve(key: key) @@ -725,7 +742,8 @@ public enum MLXRandom { /// - Parameters: /// - logits: The *unnormalized* categorical distribution(s). public static func categorical( - _ logits: MLXArray, axis: Int = -1, count: Int, key: RandomStateOrKey? = nil, + _ logits: MLXArray, axis: Int = -1, count: Int, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { let key = resolve(key: key) @@ -745,8 +763,8 @@ public enum MLXRandom { /// - loc: mean of the distribution /// - scale: scale "b" of the distribution public static func laplace( - _ shape: [Int] = [], dtype: DType = .float32, loc: Float = 0, scale: Float = 1, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + _ shape: some Collection = [], dtype: DType, loc: Float = 0, scale: Float = 1, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { let key = resolve(key: key) var result = mlx_array_new() @@ -809,7 +827,8 @@ public func split(key: MLXArray, stream: StreamOrDevice = .default) -> (MLXArray /// let array = MLXRandom.uniform(0.5 ..< 1, [50], key: key) /// ``` public func uniform( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, key: RandomStateOrKey? = nil, + _ range: Range, _ shape: some Collection = [], type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.uniform(range, shape, type: type, key: key, stream: stream) @@ -824,8 +843,8 @@ public func uniform( /// let array = MLXRandom.uniform(0.5 ..< 1, [50], key: key) /// ``` public func uniform( - _ range: Range = 0 ..< 1, _ shape: [Int] = [], type: T.Type = Float.self, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + _ range: Range = 0 ..< 1, _ shape: some Collection = [], type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.uniform(range, shape, type: type, key: key, stream: stream) } @@ -845,8 +864,10 @@ public func uniform( /// let value = MLXRandom.uniform(low: [0, 10], high: [10, 100], key: key) /// ``` public func uniform( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type = Float.self, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, + type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.uniform(low: low, high: high, shape, type: type, key: key, stream: stream) } @@ -866,8 +887,9 @@ public func uniform( /// let value = MLXRandom.uniform(low: [0, 10], high: [10, 100], key: key) /// ``` public func uniform( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, dtype: DType = .float32, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, dtype: DType, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.uniform(low: low, high: high, shape, dtype: dtype, key: key, stream: stream) } @@ -894,8 +916,8 @@ public func uniform( /// - scale: standard deviation of the distribution /// - key: PRNG key public func normal( - _ shape: [Int] = [], type: T.Type = Float.self, loc: Float = 0, scale: Float = 1, - key: RandomStateOrKey? = nil, + _ shape: some Collection = [], type: T.Type = Float.self, loc: Float = 0, scale: Float = 1, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.normal(shape, type: type, loc: loc, scale: scale, key: key, stream: stream) @@ -923,8 +945,8 @@ public func normal( /// - scale: standard deviation of the distribution /// - key: PRNG key public func normal( - _ shape: [Int] = [], dtype: DType = .float32, loc: Float = 0, scale: Float = 1, - key: RandomStateOrKey? = nil, + _ shape: some Collection = [], dtype: DType, loc: Float = 0, scale: Float = 1, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.normal(shape, dtype: dtype, loc: loc, scale: scale, key: key, stream: stream) @@ -947,8 +969,8 @@ public func normal( /// - dtype: DType of the result /// - key: PRNG key public func multivariateNormal( - mean: MLXArray, covariance: MLXArray, shape: [Int] = [], dtype: DType = .float32, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + mean: MLXArray, covariance: MLXArray, shape: some Collection = [], dtype: DType, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.multivariateNormal( mean: mean, covariance: covariance, shape: shape, dtype: dtype, key: key, stream: stream) @@ -970,7 +992,8 @@ public func multivariateNormal( /// let array = MLXRandom.randInt(Int32(0) ..< 100, [50], key: key) /// ``` public func randInt( - _ range: Range, _ shape: [Int] = [], key: RandomStateOrKey? = nil, + _ range: Range, _ shape: some Collection = [], + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryInteger { return MLXRandom.randInt(range, shape, key: key, stream: stream) @@ -990,7 +1013,9 @@ public func randInt( /// let array = MLXRandom.randInt(low: [0, 10], high: [10, 100], key: key) /// ``` public func randInt( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, key: RandomStateOrKey? = nil, + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.randInt(low: low, high: high, shape, key: key, stream: stream) @@ -1011,8 +1036,10 @@ public func randInt( /// let array = MLXRandom.randInt(low: [0, 10], high: [10, 100], type: Int8.self, key: key) /// ``` public func randInt( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, + type: T.Type, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryInteger { return MLXRandom.randInt(low: low, high: high, shape, type: type, key: key, stream: stream) } @@ -1032,7 +1059,8 @@ public func randInt( /// let array = MLXRandom.bernoulli([50, 2], key: key) /// ``` public func bernoulli( - _ shape: [Int] = [], key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + _ shape: some Collection = [], key: (some RandomStateOrKey)? = MLXArray?.none, + stream: StreamOrDevice = .default ) -> MLXArray { @@ -1058,7 +1086,8 @@ public func bernoulli( /// let array = MLXRandom.bernoulli(MLXArray(convert: [0.1, 0.5, 0.8]), key: key) /// ``` public func bernoulli( - _ p: ScalarOrArray, _ shape: [Int]? = nil, key: RandomStateOrKey? = nil, + _ p: some ScalarOrArray, _ shape: (some Collection)? = [Int]?.none, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.bernoulli(p, shape, key: key, stream: stream) @@ -1083,7 +1112,8 @@ public func bernoulli( /// ### See also /// - [JAX Documentation](https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal) public func truncatedNormal( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, key: RandomStateOrKey? = nil, + _ range: Range, _ shape: some Collection = [], type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.truncatedNormal(range, shape, type: type, key: key, stream: stream) @@ -1098,8 +1128,8 @@ public func truncatedNormal( /// let array = MLXRandom.truncatedNormal(0.5 ..< 1, [50], key: key) /// ``` public func truncatedNormal( - _ range: Range, _ shape: [Int] = [], type: T.Type = Float.self, - key: RandomStateOrKey? = nil, + _ range: Range, _ shape: some Collection = [], type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.truncatedNormal(range, shape, type: type, key: key, stream: stream) @@ -1119,8 +1149,10 @@ public func truncatedNormal( /// let value = MLXRandom.truncatedNormal([0, 10], [10, 100], key: key) /// ``` public func truncatedNormal( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, type: T.Type = Float.self, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, + type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.truncatedNormal( low: low, high: high, shape, type: type, key: key, stream: stream) @@ -1140,8 +1172,10 @@ public func truncatedNormal( /// let value = MLXRandom.truncatedNormal([0, 10], [10, 100], key: key) /// ``` public func truncatedNormal( - low: ScalarOrArray, high: ScalarOrArray, _ shape: [Int]? = nil, dtype: DType = .float32, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + low: some ScalarOrArray, high: some ScalarOrArray, + _ shape: (some Collection)? = [Int]?.none, + dtype: DType, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.truncatedNormal( low: low, high: high, shape, dtype: dtype, key: key, stream: stream) @@ -1162,7 +1196,8 @@ public func truncatedNormal( /// let array = MLXRandom.gumbel([10, 5], key: key) /// ``` public func gumbel( - _ shape: [Int] = [], type: T.Type = Float.self, key: RandomStateOrKey? = nil, + _ shape: some Collection = [], type: T.Type = Float.self, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray where T: HasDType, T: BinaryFloatingPoint { return MLXRandom.gumbel(shape, type: type, key: key, stream: stream) @@ -1183,7 +1218,8 @@ public func gumbel( /// let array = MLXRandom.gumbel([10, 5], key: key) /// ``` public func gumbel( - _ shape: [Int] = [], dtype: DType = .float32, key: RandomStateOrKey? = nil, + _ shape: some Collection = [], dtype: DType, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.gumbel(shape, dtype: dtype, key: key, stream: stream) @@ -1208,7 +1244,8 @@ public func gumbel( /// - Parameters: /// - logits: The *unnormalized* categorical distribution(s). public func categorical( - _ logits: MLXArray, axis: Int = -1, shape: [Int]? = nil, key: RandomStateOrKey? = nil, + _ logits: MLXArray, axis: Int = -1, shape: (some Collection)? = [Int]?.none, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.categorical(logits, axis: axis, shape: shape, key: key, stream: stream) @@ -1231,7 +1268,7 @@ public func categorical( /// - Parameters: /// - logits: The *unnormalized* categorical distribution(s). public func categorical( - _ logits: MLXArray, axis: Int = -1, count: Int, key: RandomStateOrKey? = nil, + _ logits: MLXArray, axis: Int = -1, count: Int, key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.categorical(logits, axis: axis, count: count, key: key, stream: stream) @@ -1245,8 +1282,8 @@ public func categorical( /// - loc: mean of the distribution /// - scale: scale "b" of the distribution public func laplace( - _ shape: [Int] = [], dtype: DType = .float32, loc: Float = 0, scale: Float = 1, - key: RandomStateOrKey? = nil, stream: StreamOrDevice = .default + _ shape: some Collection = [], dtype: DType = .float32, loc: Float = 0, scale: Float = 1, + key: (some RandomStateOrKey)? = MLXArray?.none, stream: StreamOrDevice = .default ) -> MLXArray { return MLXRandom.laplace(shape, dtype: dtype, loc: loc, scale: scale, key: key, stream: stream) } diff --git a/Source/MLX/State.swift b/Source/MLX/State.swift index 63e59aae..fc49dabc 100644 --- a/Source/MLX/State.swift +++ b/Source/MLX/State.swift @@ -90,7 +90,7 @@ extension MLXRandom { /// - the passed key, either an ``MLXArray`` or ``MLXRandom/RandomState`` /// - the task-local ``MLXRandom/RandomState``, see ``withRandomState(_:body:)-6i2p1`` /// - the global RandomState, ``MLXRandom/globalState`` -public func resolve(key: RandomStateOrKey?) -> MLXArray { +public func resolve(key: (some RandomStateOrKey)? = MLXArray?.none) -> MLXArray { key?.asRandomKey() ?? MLXRandom.taskLocalRandomState?.asRandomKey() ?? MLXRandom.globalState.next() } diff --git a/Source/MLX/Transforms+Eval.swift b/Source/MLX/Transforms+Eval.swift index aabdd2e7..88356a1b 100644 --- a/Source/MLX/Transforms+Eval.swift +++ b/Source/MLX/Transforms+Eval.swift @@ -24,7 +24,7 @@ public func eval(_ arrays: MLXArray...) { /// /// ### See Also /// - -public func eval(_ arrays: [MLXArray]) { +public func eval(_ arrays: some Collection) { let vector_array = new_mlx_vector_array(arrays) _ = evalLock.withLock { mlx_eval(vector_array) @@ -37,7 +37,7 @@ public func eval(_ arrays: [MLXArray]) { /// ### See Also /// - /// - ``asyncEval(_:)-6j4zg`` -public func asyncEval(_ arrays: [MLXArray]) { +public func asyncEval(_ arrays: some Collection) { let vector_array = new_mlx_vector_array(arrays) _ = evalLock.withLock { mlx_async_eval(vector_array) @@ -79,7 +79,7 @@ public func eval(_ values: Any...) { /// Evaluate one or more `MLXArray`. /// /// See ``eval(_:)`` -public func eval(_ values: [Any]) { +public func eval(_ values: some Sequence) { var arrays = [MLXArray]() for item in values { @@ -109,7 +109,7 @@ public func checkedEval(_ values: Any...) throws { /// /// ### See Also /// - -public func checkedEval(_ values: [Any]) throws { +public func checkedEval(_ values: some Sequence) throws { var arrays = [MLXArray]() for item in values { @@ -154,7 +154,7 @@ public func asyncEval(_ values: Any...) { /// Evaluate one or more `MLXArray` asynchronously. /// /// See ``asyncEval(_:)-6j4zg`` -public func asyncEval(_ values: [Any]) { +public func asyncEval(_ values: some Sequence) { var arrays = [MLXArray]() for item in values { diff --git a/Source/MLX/Transforms+Grad.swift b/Source/MLX/Transforms+Grad.swift index c64b2a05..e8d29e71 100644 --- a/Source/MLX/Transforms+Grad.swift +++ b/Source/MLX/Transforms+Grad.swift @@ -4,7 +4,9 @@ import Foundation // This file is generated by GenerateGrad. // Returns a function which computes the gradient of `f`. -public func grad(_ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: [Int] = [0]) -> ( +public func grad( + _ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: some Collection = [0] +) -> ( [MLXArray] ) -> [MLXArray] { // Converts the given function `f()` into canonical types, e.g. @@ -26,7 +28,9 @@ public func grad(_ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: [In } // See ``grad(_:)-r8dv`` -public func grad(_ f: @escaping ([MLXArray]) -> MLXArray, argumentNumbers: [Int] = [0]) -> ( +public func grad( + _ f: @escaping ([MLXArray]) -> MLXArray, argumentNumbers: some Collection = [0] +) -> ( [MLXArray] ) -> MLXArray { let wrappedFunction = wrapResult(wrapArguments(f)) @@ -52,7 +56,9 @@ public func grad(_ f: @escaping (MLXArray) -> MLXArray) -> (MLXArray) -> MLXArra } // Returns a function which computes the value and gradient of `f`. -public func valueAndGrad(_ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: [Int] = [0]) +public func valueAndGrad( + _ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: some Collection = [0] +) -> ([MLXArray]) -> ([MLXArray], [MLXArray]) { // diff --git a/Source/MLX/Transforms+Internal.swift b/Source/MLX/Transforms+Internal.swift index 7a05fb0d..e2a4d52d 100644 --- a/Source/MLX/Transforms+Internal.swift +++ b/Source/MLX/Transforms+Internal.swift @@ -5,7 +5,9 @@ import Foundation // see Transforms+Variants for generated grad() functions -private func valueAndGradient(apply valueAndGrad: mlx_closure_value_and_grad, arrays: [MLXArray]) +private func valueAndGradient( + apply valueAndGrad: mlx_closure_value_and_grad, arrays: some Collection +) -> ([MLXArray], [MLXArray]) { let input_vector = new_mlx_vector_array(arrays) @@ -21,9 +23,11 @@ private func valueAndGradient(apply valueAndGrad: mlx_closure_value_and_grad, ar return (mlx_vector_array_values(r0), mlx_vector_array_values(r1)) } -func buildGradient(_ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: [Int]) -> ( - [MLXArray] -) -> [MLXArray] { +func buildGradient(_ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: some Collection) + -> ( + [MLXArray] + ) -> [MLXArray] +{ { (arrays: [MLXArray]) in var vag = mlx_closure_value_and_grad_new() @@ -37,7 +41,9 @@ func buildGradient(_ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: [ } } -func buildValueAndGradient(_ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: [Int]) -> ( +func buildValueAndGradient( + _ f: @escaping ([MLXArray]) -> [MLXArray], argumentNumbers: some Collection +) -> ( [MLXArray] ) -> ([MLXArray], [MLXArray]) { { (arrays: [MLXArray]) in diff --git a/Source/MLX/Transforms+Vmap.swift b/Source/MLX/Transforms+Vmap.swift index 052f7b9f..7f7ba5fc 100644 --- a/Source/MLX/Transforms+Vmap.swift +++ b/Source/MLX/Transforms+Vmap.swift @@ -16,8 +16,8 @@ import Foundation /// - public func vmap( _ f: @escaping ([MLXArray]) -> [MLXArray], - inAxes: [Int?] = [0], - outAxes: [Int?] = [0] + inAxes: some Sequence = [0], + outAxes: some Sequence = [0] ) -> ([MLXArray]) -> [MLXArray] { { arrays in let inAxes32 = inAxes.map { Int32($0 ?? -1) } diff --git a/Source/MLX/Transforms.swift b/Source/MLX/Transforms.swift index 32d4fa7c..ab37e688 100644 --- a/Source/MLX/Transforms.swift +++ b/Source/MLX/Transforms.swift @@ -15,7 +15,8 @@ import Foundation /// should be the same in number, shape and type as the inputs of `f`, e.g. the `primals` /// - Returns: array of the Jacobian-vector products which is the same in number, shape and type of the outputs of `f` public func jvp( - _ f: @escaping ([MLXArray]) -> [MLXArray], primals: [MLXArray], tangents: [MLXArray] + _ f: @escaping ([MLXArray]) -> [MLXArray], primals: some Collection, + tangents: some Collection ) -> ([MLXArray], [MLXArray]) { let primals_mlx = new_mlx_vector_array(primals) defer { mlx_vector_array_free(primals_mlx) } @@ -47,7 +48,8 @@ public func jvp( /// should be the same in number, shape and type as the outputs of `f` /// - Returns: array of the vector-Jacobian products which is the same in number, shape and type of the outputs of `f` public func vjp( - _ f: @escaping ([MLXArray]) -> [MLXArray], primals: [MLXArray], cotangents: [MLXArray] + _ f: @escaping ([MLXArray]) -> [MLXArray], primals: some Collection, + cotangents: some Collection ) -> ([MLXArray], [MLXArray]) { let primals_mlx = new_mlx_vector_array(primals) defer { mlx_vector_array_free(primals_mlx) } diff --git a/Source/MLXOptimizers/Optimizers.swift b/Source/MLXOptimizers/Optimizers.swift index 48041b5b..e0d2ee37 100644 --- a/Source/MLXOptimizers/Optimizers.swift +++ b/Source/MLXOptimizers/Optimizers.swift @@ -664,7 +664,7 @@ open class Adafactor: OptimizerBase { /// - Parameters: /// - gradients: an array of MLXArray /// - maxNorm: the maximum allowed global norm of th gradients -public func clipGradNorm(gradients: [MLXArray], maxNorm: Float) -> ([MLXArray], MLXArray) { +public func clipGradNorm(gradients: some Collection, maxNorm: Float) -> ([MLXArray], MLXArray) { let normSquared = gradients.reduce(MLXArray(0)) { $0 + $1.square().sum() } let totalNorm = sqrt(normSquared) let normalizer = maxNorm / (totalNorm + 1e-6)