Skip to content

Commit abe1092

Browse files
authored
rename update -> _updateInternal to discourage callers (#230)
1 parent 1f1f922 commit abe1092

File tree

7 files changed

+21
-24
lines changed

7 files changed

+21
-24
lines changed

Source/MLX/MLXArray+Indexing.swift

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,6 @@ extension MLXArray {
6161
}
6262
}
6363

64-
/// Replace the interior ctx (`mlx_array` pointer) with a new value by transferring ownership
65-
@inline(__always)
66-
func update(_ ctx: mlx_array) {
67-
mlx_array_set(&self.ctx, ctx)
68-
}
69-
7064
/// allow addressing as a positive index or negative (from end) using given axis
7165
@inlinable
7266
func resolve(index: Int, axis: Int) -> MLXArray {
@@ -161,7 +155,7 @@ extension MLXArray {
161155
.broadcast(to: broadcastShape.asInt32)
162156

163157
let indices = [resolve(index: index, axis: axis)]
164-
self.update(scattered(indices: indices, updates: expanded, axes: [axis.int32]))
158+
self._updateInternal(scattered(indices: indices, updates: expanded, axes: [axis.int32]))
165159
}
166160
}
167161

@@ -350,7 +344,7 @@ extension MLXArray {
350344
let update = newValue.broadcast(to: broadcastShape).reshaped(updateShape)
351345

352346
let axes = arange(axis + 1)
353-
self.update(scattered(indices: arrayIndices, updates: update, axes: axes))
347+
self._updateInternal(scattered(indices: arrayIndices, updates: update, axes: axes))
354348
}
355349
}
356350

@@ -380,7 +374,7 @@ extension MLXArray {
380374
}
381375
set {
382376
if let result = updateSlice(src: self, operations: operations, update: newValue) {
383-
self.update(result)
377+
self._updateInternal(result)
384378
return
385379
}
386380

@@ -393,11 +387,11 @@ extension MLXArray {
393387
var result = mlx_array_new()
394388
mlx_scatter(
395389
&result, self.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx)
396-
self.update(result)
390+
mlx_array_set(&self.ctx, result)
397391
mlx_array_free(result)
398392
return
399393
} else {
400-
self.update(update)
394+
self._updateInternal(update)
401395
return
402396
}
403397
}

Source/MLX/MLXArray+Ops.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ extension MLXArray {
6363
/// - <doc:arithmetic>
6464
/// - ``add(_:_:stream:)``
6565
public static func += (lhs: inout MLXArray, rhs: MLXArray) {
66-
lhs.update(lhs + rhs)
66+
lhs._updateInternal(lhs + rhs)
6767
}
6868

6969
/// Element-wise addition with a ``ScalarOrArray`` (scalar) argument.
@@ -133,7 +133,7 @@ extension MLXArray {
133133
/// - <doc:arithmetic>
134134
/// - ``subtract(_:_:stream:)``
135135
public static func -= (lhs: inout MLXArray, rhs: MLXArray) {
136-
lhs.update(lhs - rhs)
136+
lhs._updateInternal(lhs - rhs)
137137
}
138138

139139
/// Element-wise subtraction with a ``ScalarOrArray`` (scalar) argument.
@@ -224,7 +224,7 @@ extension MLXArray {
224224
/// - ``matmul(_:stream:)``
225225
/// - ``matmul(_:_:stream:)``
226226
public static func *= (lhs: inout MLXArray, rhs: MLXArray) {
227-
lhs.update(lhs * rhs)
227+
lhs._updateInternal(lhs * rhs)
228228
}
229229

230230
/// Element-wise multiplication with a ``ScalarOrArray`` (scalar) argument.
@@ -340,7 +340,7 @@ extension MLXArray {
340340
/// - ``divide(_:_:stream:)``
341341
/// - ``floorDivide(_:_:stream:)``
342342
public static func /= (lhs: inout MLXArray, rhs: MLXArray) {
343-
lhs.update(lhs / rhs)
343+
lhs._updateInternal(lhs / rhs)
344344
}
345345

346346
/// Element-wise division with a ``ScalarOrArray`` (scalar) argument.

Source/MLX/MLXArray.swift

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,11 @@ public final class MLXArray {
539539
mlx_array_eval(ctx)
540540
}
541541

542-
/// Replace the contents with a reference to a new array.
543-
public func update(_ array: MLXArray) {
542+
/// Replace the contents with a reference to a new array (INTERNAL).
543+
///
544+
/// Note: this is an implementation detail and only visible because of the need to call it from
545+
/// other `mlx-swift` modules.
546+
public func _updateInternal(_ array: MLXArray) {
544547
mlx_array_set(&self.ctx, array.ctx)
545548
}
546549

Source/MLX/Protocols.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import Foundation
55
/// An object that can provide a list of the ``MLXArray`` in its inner state.
66
///
77
/// Note that the array itself is not a reference to the inner state, but the ``MLXArray`` instances
8-
/// can be ``MLXArray/update(_:)`` to mutate the inner state. The exact working is an
8+
/// can be ``MLXArray/_updateInternal(_:)`` to mutate the inner state. The exact working is an
99
/// implemention detail for MLX and should not be depended on by outside callers.
1010
///
1111
/// ### See Also

Source/MLX/Transforms+Compile.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ final class CompiledFunction: @unchecked (Sendable) {
6262

6363
// replace the inner state with the tracers
6464
for (s, tracer) in zip(stateInputs, tracers[argumentsCount...]) {
65-
s.update(tracer)
65+
s._updateInternal(tracer)
6666
}
6767

6868
// call the function with the tracer arguments
@@ -74,7 +74,7 @@ final class CompiledFunction: @unchecked (Sendable) {
7474

7575
// put the original values back in the state
7676
for (s, saved) in zip(stateInputs, savedStateInputs) {
77-
s.update(saved)
77+
s._updateInternal(saved)
7878
}
7979

8080
// return the result of the function and the state
@@ -106,7 +106,7 @@ final class CompiledFunction: @unchecked (Sendable) {
106106
let stateOutput = outputs.flatMap { $0.innerState() }
107107

108108
for (s, newValues) in zip(stateOutput, resultsPlusStateOutput.suffix(stateOutput.count)) {
109-
s.update(newValues)
109+
s._updateInternal(newValues)
110110
}
111111

112112
let resultLength = resultsPlusStateOutput.count - stateOutput.count

Source/MLXNN/Module.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ open class Module {
448448
throw UpdateError.mismatchedSize(
449449
key: key, expectedShape: p.shape, actualShape: newArray.shape)
450450
}
451-
p.update(newArray)
451+
p._updateInternal(newArray)
452452

453453
case (.value(.parameters(let p)), .none):
454454
if Self.parameterIsValid(key) {

Source/MLXNN/Normalization.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,8 @@ open class BatchNorm: Module, UnaryLayer {
336336

337337
if self.training, let runningMean, let runningVar {
338338
let mu = momentum
339-
runningMean.update((1 - mu) * runningMean + mu * mean)
340-
runningVar.update((1 - mu) * runningVar + mu * variance)
339+
runningMean._updateInternal((1 - mu) * runningMean + mu * mean)
340+
runningVar._updateInternal((1 - mu) * runningVar + mu * variance)
341341

342342
} else if let runningMean, let runningVar {
343343
mean = runningMean

0 commit comments

Comments
 (0)