Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions Source/MLX/ArrayAt.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public struct ArrayAt {
/// ### See Also
/// - ``MLXArray/at``
/// - ``ArrayAtIndices``
public subscript(indices: MLXArrayIndex..., stream stream: StreamOrDevice = .default)
public subscript(indices: any MLXArrayIndex..., stream stream: StreamOrDevice = .default)
-> ArrayAtIndices
{
get {
Expand All @@ -58,7 +58,9 @@ public struct ArrayAt {
/// ### See Also
/// - ``MLXArray/at``
/// - ``ArrayAtIndices``
public subscript(indices: [MLXArrayIndex], stream stream: StreamOrDevice = .default)
public subscript(indices: some Sequence<any MLXArrayIndex>,
stream stream: StreamOrDevice = .default
)
-> ArrayAtIndices
{
get {
Expand Down Expand Up @@ -99,7 +101,7 @@ public struct ArrayAtIndices {
///
/// ### See Also
/// - ``MLXArray/at``
public func add(_ values: ScalarOrArray) -> MLXArray {
public func add(_ values: some ScalarOrArray) -> MLXArray {
let values = values.asMLXArray(dtype: array.dtype)
let (indices, update, axes) = scatterArguments(
src: array, operations: indexOperations, update: values, stream: stream)
Expand Down Expand Up @@ -128,7 +130,7 @@ public struct ArrayAtIndices {
///
/// ### See Also
/// - ``MLXArray/at``
public func subtract(_ values: ScalarOrArray) -> MLXArray {
public func subtract(_ values: some ScalarOrArray) -> MLXArray {
add(-values.asMLXArray(dtype: array.dtype))
}

Expand All @@ -142,7 +144,7 @@ public struct ArrayAtIndices {
///
/// ### See Also
/// - ``MLXArray/at``
public func multiply(_ values: ScalarOrArray) -> MLXArray {
public func multiply(_ values: some ScalarOrArray) -> MLXArray {
let values = values.asMLXArray(dtype: array.dtype)
let (indices, update, axes) = scatterArguments(
src: array, operations: indexOperations, update: values, stream: stream)
Expand Down Expand Up @@ -171,7 +173,7 @@ public struct ArrayAtIndices {
///
/// ### See Also
/// - ``MLXArray/at``
public func divide(_ values: ScalarOrArray) -> MLXArray {
public func divide(_ values: some ScalarOrArray) -> MLXArray {
multiply(values.asMLXArray(dtype: array.dtype).reciprocal())
}

Expand All @@ -185,7 +187,7 @@ public struct ArrayAtIndices {
///
/// ### See Also
/// - ``MLXArray/at``
public func minimum(_ values: ScalarOrArray) -> MLXArray {
public func minimum(_ values: some ScalarOrArray) -> MLXArray {
let values = values.asMLXArray(dtype: array.dtype)
let (indices, update, axes) = scatterArguments(
src: array, operations: indexOperations, update: values, stream: stream)
Expand Down Expand Up @@ -214,7 +216,7 @@ public struct ArrayAtIndices {
///
/// ### See Also
/// - ``MLXArray/at``
public func maximum(_ values: ScalarOrArray) -> MLXArray {
public func maximum(_ values: some ScalarOrArray) -> MLXArray {
let values = values.asMLXArray(dtype: array.dtype)
let (indices, update, axes) = scatterArguments(
src: array, operations: indexOperations, update: values, stream: stream)
Expand Down
2 changes: 1 addition & 1 deletion Source/MLX/Cmlx+Util.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import Cmlx
import Foundation

// return a +1 mlx_vector_array containing the given arrays
func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array {
func new_mlx_vector_array(_ arrays: some Collection<MLXArray>) -> mlx_vector_array {
withExtendedLifetime(arrays) {
mlx_vector_array_new_data(arrays.map { $0.ctx }, arrays.count)
}
Expand Down
52 changes: 36 additions & 16 deletions Source/MLX/FFT.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ public enum MLXFFT {
/// ### See Also
/// - <doc:MLXFFT>
public static func fft2(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1],
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [-2, -1],
stream: StreamOrDevice = .default
) -> MLXArray {
fftn(array, s: s, axes: axes, stream: stream)
Expand All @@ -81,7 +82,8 @@ public enum MLXFFT {
/// ### See Also
/// - <doc:MLXFFT>
public static func ifft2(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1],
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [-2, -1],
stream: StreamOrDevice = .default
) -> MLXArray {
ifftn(array, s: s, axes: axes, stream: stream)
Expand All @@ -100,7 +102,8 @@ public enum MLXFFT {
/// ### See Also
/// - <doc:MLXFFT>
public static func fftn(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
if let s, let axes {
Expand Down Expand Up @@ -146,7 +149,8 @@ public enum MLXFFT {
/// ### See Also
/// - <doc:MLXFFT>
public static func ifftn(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
if let s, let axes {
Expand Down Expand Up @@ -244,7 +248,8 @@ public enum MLXFFT {
/// ### See Also
/// - <doc:MLXFFT>
public static func rfft2(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1],
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [-2, -1],
stream: StreamOrDevice = .default
) -> MLXArray {
rfftn(array, s: s, axes: axes, stream: stream)
Expand All @@ -268,7 +273,8 @@ public enum MLXFFT {
/// ### See Also
/// - <doc:MLXFFT>
public static func irfft2(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1],
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [-2, -1],
stream: StreamOrDevice = .default
) -> MLXArray {
irfftn(array, s: s, axes: axes, stream: stream)
Expand All @@ -291,7 +297,8 @@ public enum MLXFFT {
/// ### See Also
/// - <doc:MLXFFT>
public static func rfftn(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
if let s, let axes {
Expand Down Expand Up @@ -342,7 +349,8 @@ public enum MLXFFT {
/// ### See Also
/// - <doc:MLXFFT>
public static func irfftn(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
) -> MLXArray {
var result = mlx_array_new()
if let s, let axes {
Expand Down Expand Up @@ -428,7 +436,9 @@ public func ifft(
/// ### See Also
/// - <doc:MLXFFT>
public func fft2(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [-2, -1],
stream: StreamOrDevice = .default
) -> MLXArray {
MLXFFT.fft2(array, s: s, axes: axes, stream: stream)
}
Expand All @@ -446,7 +456,9 @@ public func fft2(
/// ### See Also
/// - <doc:MLXFFT>
public func ifft2(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [-2, -1],
stream: StreamOrDevice = .default
) -> MLXArray {
MLXFFT.ifft2(array, s: s, axes: axes, stream: stream)
}
Expand All @@ -464,7 +476,8 @@ public func ifft2(
/// ### See Also
/// - <doc:MLXFFT>
public func fftn(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
) -> MLXArray {
MLXFFT.fftn(array, s: s, axes: axes, stream: stream)
}
Expand All @@ -482,7 +495,8 @@ public func fftn(
/// ### See Also
/// - <doc:MLXFFT>
public func ifftn(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
) -> MLXArray {
MLXFFT.ifftn(array, s: s, axes: axes, stream: stream)
}
Expand Down Expand Up @@ -546,7 +560,9 @@ public func irfft(
/// ### See Also
/// - <doc:MLXFFT>
public func rfft2(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [-2, -1],
stream: StreamOrDevice = .default
) -> MLXArray {
MLXFFT.rfft2(array, s: s, axes: axes, stream: stream)
}
Expand All @@ -569,7 +585,9 @@ public func rfft2(
/// ### See Also
/// - <doc:MLXFFT>
public func irfft2(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [-2, -1],
stream: StreamOrDevice = .default
) -> MLXArray {
MLXFFT.irfft2(array, s: s, axes: axes, stream: stream)
}
Expand All @@ -591,7 +609,8 @@ public func irfft2(
/// ### See Also
/// - <doc:MLXFFT>
public func rfftn(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
) -> MLXArray {
MLXFFT.rfftn(array, s: s, axes: axes, stream: stream)
}
Expand All @@ -614,7 +633,8 @@ public func rfftn(
/// ### See Also
/// - <doc:MLXFFT>
public func irfftn(
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
) -> MLXArray {
MLXFFT.irfftn(array, s: s, axes: axes, stream: stream)
}
Loading