Skip to content

Commit 7c306ae

Browse files
committed
genericized function arguments
- replaced the use of concrete arrays with Collection and Sequence, when applicable - replaced the use of existentials with generics - removed the default arguments for `dtype:` to resolve the ambiguities arisen after the changes
1 parent 8f9f747 commit 7c306ae

22 files changed

+330
-245
lines changed

Source/MLX/ArrayAt.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public struct ArrayAt {
3434
/// ### See Also
3535
/// - ``MLXArray/at``
3636
/// - ``ArrayAtIndices``
37-
public subscript(indices: MLXArrayIndex..., stream stream: StreamOrDevice = .default)
37+
public subscript(indices: any MLXArrayIndex..., stream stream: StreamOrDevice = .default)
3838
-> ArrayAtIndices
3939
{
4040
get {
@@ -58,7 +58,7 @@ public struct ArrayAt {
5858
/// ### See Also
5959
/// - ``MLXArray/at``
6060
/// - ``ArrayAtIndices``
61-
public subscript(indices: [MLXArrayIndex], stream stream: StreamOrDevice = .default)
61+
public subscript(indices: some Sequence<any MLXArrayIndex>, stream stream: StreamOrDevice = .default)
6262
-> ArrayAtIndices
6363
{
6464
get {
@@ -99,7 +99,7 @@ public struct ArrayAtIndices {
9999
///
100100
/// ### See Also
101101
/// - ``MLXArray/at``
102-
public func add(_ values: ScalarOrArray) -> MLXArray {
102+
public func add(_ values: some ScalarOrArray) -> MLXArray {
103103
let values = values.asMLXArray(dtype: array.dtype)
104104
let (indices, update, axes) = scatterArguments(
105105
src: array, operations: indexOperations, update: values, stream: stream)
@@ -128,7 +128,7 @@ public struct ArrayAtIndices {
128128
///
129129
/// ### See Also
130130
/// - ``MLXArray/at``
131-
public func subtract(_ values: ScalarOrArray) -> MLXArray {
131+
public func subtract(_ values: some ScalarOrArray) -> MLXArray {
132132
add(-values.asMLXArray(dtype: array.dtype))
133133
}
134134

@@ -142,7 +142,7 @@ public struct ArrayAtIndices {
142142
///
143143
/// ### See Also
144144
/// - ``MLXArray/at``
145-
public func multiply(_ values: ScalarOrArray) -> MLXArray {
145+
public func multiply(_ values: some ScalarOrArray) -> MLXArray {
146146
let values = values.asMLXArray(dtype: array.dtype)
147147
let (indices, update, axes) = scatterArguments(
148148
src: array, operations: indexOperations, update: values, stream: stream)
@@ -171,7 +171,7 @@ public struct ArrayAtIndices {
171171
///
172172
/// ### See Also
173173
/// - ``MLXArray/at``
174-
public func divide(_ values: ScalarOrArray) -> MLXArray {
174+
public func divide(_ values: some ScalarOrArray) -> MLXArray {
175175
multiply(values.asMLXArray(dtype: array.dtype).reciprocal())
176176
}
177177

@@ -185,7 +185,7 @@ public struct ArrayAtIndices {
185185
///
186186
/// ### See Also
187187
/// - ``MLXArray/at``
188-
public func minimum(_ values: ScalarOrArray) -> MLXArray {
188+
public func minimum(_ values: some ScalarOrArray) -> MLXArray {
189189
let values = values.asMLXArray(dtype: array.dtype)
190190
let (indices, update, axes) = scatterArguments(
191191
src: array, operations: indexOperations, update: values, stream: stream)
@@ -214,7 +214,7 @@ public struct ArrayAtIndices {
214214
///
215215
/// ### See Also
216216
/// - ``MLXArray/at``
217-
public func maximum(_ values: ScalarOrArray) -> MLXArray {
217+
public func maximum(_ values: some ScalarOrArray) -> MLXArray {
218218
let values = values.asMLXArray(dtype: array.dtype)
219219
let (indices, update, axes) = scatterArguments(
220220
src: array, operations: indexOperations, update: values, stream: stream)

Source/MLX/Cmlx+Util.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import Cmlx
44
import Foundation
55

66
// return a +1 mlx_vector_array containing the given arrays
7-
func new_mlx_vector_array(_ arrays: [MLXArray]) -> mlx_vector_array {
7+
func new_mlx_vector_array(_ arrays: some Collection<MLXArray>) -> mlx_vector_array {
88
withExtendedLifetime(arrays) {
99
mlx_vector_array_new_data(arrays.map { $0.ctx }, arrays.count)
1010
}

Source/MLX/FFT.swift

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ public enum MLXFFT {
6262
/// ### See Also
6363
/// - <doc:MLXFFT>
6464
public static func fft2(
65-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1],
65+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
66+
axes: (some Collection<Int>)? = [-2, -1],
6667
stream: StreamOrDevice = .default
6768
) -> MLXArray {
6869
fftn(array, s: s, axes: axes, stream: stream)
@@ -81,7 +82,8 @@ public enum MLXFFT {
8182
/// ### See Also
8283
/// - <doc:MLXFFT>
8384
public static func ifft2(
84-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1],
85+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
86+
axes: (some Collection<Int>)? = [-2, -1],
8587
stream: StreamOrDevice = .default
8688
) -> MLXArray {
8789
ifftn(array, s: s, axes: axes, stream: stream)
@@ -100,7 +102,8 @@ public enum MLXFFT {
100102
/// ### See Also
101103
/// - <doc:MLXFFT>
102104
public static func fftn(
103-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
105+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
106+
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
104107
) -> MLXArray {
105108
var result = mlx_array_new()
106109
if let s, let axes {
@@ -146,7 +149,8 @@ public enum MLXFFT {
146149
/// ### See Also
147150
/// - <doc:MLXFFT>
148151
public static func ifftn(
149-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
152+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
153+
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
150154
) -> MLXArray {
151155
var result = mlx_array_new()
152156
if let s, let axes {
@@ -244,7 +248,8 @@ public enum MLXFFT {
244248
/// ### See Also
245249
/// - <doc:MLXFFT>
246250
public static func rfft2(
247-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1],
251+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
252+
axes: (some Collection<Int>)? = [-2, -1],
248253
stream: StreamOrDevice = .default
249254
) -> MLXArray {
250255
rfftn(array, s: s, axes: axes, stream: stream)
@@ -268,7 +273,8 @@ public enum MLXFFT {
268273
/// ### See Also
269274
/// - <doc:MLXFFT>
270275
public static func irfft2(
271-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1],
276+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
277+
axes: (some Collection<Int>)? = [-2, -1],
272278
stream: StreamOrDevice = .default
273279
) -> MLXArray {
274280
irfftn(array, s: s, axes: axes, stream: stream)
@@ -291,7 +297,8 @@ public enum MLXFFT {
291297
/// ### See Also
292298
/// - <doc:MLXFFT>
293299
public static func rfftn(
294-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
300+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
301+
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
295302
) -> MLXArray {
296303
var result = mlx_array_new()
297304
if let s, let axes {
@@ -342,7 +349,8 @@ public enum MLXFFT {
342349
/// ### See Also
343350
/// - <doc:MLXFFT>
344351
public static func irfftn(
345-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
352+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
353+
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
346354
) -> MLXArray {
347355
var result = mlx_array_new()
348356
if let s, let axes {
@@ -428,7 +436,9 @@ public func ifft(
428436
/// ### See Also
429437
/// - <doc:MLXFFT>
430438
public func fft2(
431-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default
439+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
440+
axes: (some Collection<Int>)? = [-2, -1],
441+
stream: StreamOrDevice = .default
432442
) -> MLXArray {
433443
MLXFFT.fft2(array, s: s, axes: axes, stream: stream)
434444
}
@@ -446,7 +456,9 @@ public func fft2(
446456
/// ### See Also
447457
/// - <doc:MLXFFT>
448458
public func ifft2(
449-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default
459+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
460+
axes: (some Collection<Int>)? = [-2, -1],
461+
stream: StreamOrDevice = .default
450462
) -> MLXArray {
451463
MLXFFT.ifft2(array, s: s, axes: axes, stream: stream)
452464
}
@@ -464,7 +476,8 @@ public func ifft2(
464476
/// ### See Also
465477
/// - <doc:MLXFFT>
466478
public func fftn(
467-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
479+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
480+
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
468481
) -> MLXArray {
469482
MLXFFT.fftn(array, s: s, axes: axes, stream: stream)
470483
}
@@ -482,7 +495,8 @@ public func fftn(
482495
/// ### See Also
483496
/// - <doc:MLXFFT>
484497
public func ifftn(
485-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
498+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
499+
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
486500
) -> MLXArray {
487501
MLXFFT.ifftn(array, s: s, axes: axes, stream: stream)
488502
}
@@ -546,7 +560,9 @@ public func irfft(
546560
/// ### See Also
547561
/// - <doc:MLXFFT>
548562
public func rfft2(
549-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default
563+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
564+
axes: (some Collection<Int>)? = [-2, -1],
565+
stream: StreamOrDevice = .default
550566
) -> MLXArray {
551567
MLXFFT.rfft2(array, s: s, axes: axes, stream: stream)
552568
}
@@ -569,7 +585,9 @@ public func rfft2(
569585
/// ### See Also
570586
/// - <doc:MLXFFT>
571587
public func irfft2(
572-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = [-2, -1], stream: StreamOrDevice = .default
588+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
589+
axes: (some Collection<Int>)? = [-2, -1],
590+
stream: StreamOrDevice = .default
573591
) -> MLXArray {
574592
MLXFFT.irfft2(array, s: s, axes: axes, stream: stream)
575593
}
@@ -591,7 +609,8 @@ public func irfft2(
591609
/// ### See Also
592610
/// - <doc:MLXFFT>
593611
public func rfftn(
594-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
612+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
613+
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
595614
) -> MLXArray {
596615
MLXFFT.rfftn(array, s: s, axes: axes, stream: stream)
597616
}
@@ -614,7 +633,8 @@ public func rfftn(
614633
/// ### See Also
615634
/// - <doc:MLXFFT>
616635
public func irfftn(
617-
_ array: MLXArray, s: [Int]? = nil, axes: [Int]? = nil, stream: StreamOrDevice = .default
636+
_ array: MLXArray, s: (some Collection<Int>)? = [Int]?.none,
637+
axes: (some Collection<Int>)? = [Int]?.none, stream: StreamOrDevice = .default
618638
) -> MLXArray {
619639
MLXFFT.irfftn(array, s: s, axes: axes, stream: stream)
620640
}

0 commit comments

Comments
 (0)