Skip to content

Commit b94473a

Browse files
authored
partial fix for #237 -- switching Device exhausts many resources (#242)
* partial fix for #237 -- switching Device exhausts many resources * remove debug code * remove debug code * comments
1 parent 50ec1bf commit b94473a

File tree

3 files changed

+126
-22
lines changed

3 files changed

+126
-22
lines changed

Source/MLX/Device.swift

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,19 @@ public enum DeviceType: String, Hashable, Sendable {
2828
public final class Device: @unchecked Sendable, Equatable {
2929

3030
let ctx: mlx_device
31+
let defaultStream: Stream
3132

3233
init(_ ctx: mlx_device) {
3334
self.ctx = ctx
35+
36+
var deviceType = MLX_GPU
37+
mlx_device_get_type(&deviceType, ctx)
38+
self.defaultStream =
39+
switch deviceType {
40+
case MLX_CPU: .cpu
41+
case MLX_GPU: .gpu
42+
default: .gpu
43+
}
3444
}
3545

3646
public init(_ deviceType: DeviceType, index: Int32 = 0) {
@@ -42,19 +52,32 @@ public final class Device: @unchecked Sendable, Equatable {
4252
cDeviceType = MLX_GPU
4353
}
4454
self.ctx = mlx_device_new_type(cDeviceType, index)
55+
self.defaultStream =
56+
switch deviceType {
57+
case .cpu: .cpu
58+
case .gpu: .gpu
59+
}
4560
}
4661

47-
public init() {
62+
@available(*, deprecated, message: "please use defaultDevice()")
63+
public convenience init() {
4864
var ctx = mlx_device_new()
4965
mlx_get_default_device(&ctx)
50-
self.ctx = ctx
66+
self.init(ctx)
5167
}
5268

5369
deinit {
5470
mlx_device_free(ctx)
5571
}
5672

73+
/// static CPU device
74+
///
75+
/// See ``withDefaultDevice(_:_:)``
5776
static public let cpu: Device = Device(.cpu)
77+
78+
/// static GPU device
79+
///
80+
/// See ``withDefaultDevice(_:_:)``
5881
static public let gpu: Device = Device(.gpu)
5982

6083
public var deviceType: DeviceType? {
@@ -67,28 +90,50 @@ public final class Device: @unchecked Sendable, Equatable {
6790
}
6891
}
6992

93+
// support for global default device
7094
static let _lock = NSLock()
7195
#if swift(>=5.10)
72-
nonisolated(unsafe) static var _defaultDevice = gpu
73-
nonisolated(unsafe) static var _defaultStream = Stream(gpu)
96+
nonisolated(unsafe) static var _defaultDevice: Device?
7497
#else
75-
static var _defaultDevice = gpu
76-
static var _defaultStream = Stream(gpu)
98+
static var _defaultDevice: Device?
7799
#endif
78100

79-
static public func defaultDevice() -> Device {
101+
@TaskLocal static var _tlDefaultDevice = _resolveGlobalDefaultDevice()
102+
103+
private static func _resolveGlobalDefaultDevice() -> Device {
80104
_lock.withLock {
81-
_defaultDevice
105+
_defaultDevice ?? .gpu
82106
}
83107
}
84108

109+
/// Return the current default device.
110+
///
111+
/// This is used by ``StreamOrDevice/default`` -- the default stream parameter
112+
/// to most functions.
113+
static public func defaultDevice() -> Device {
114+
_tlDefaultDevice
115+
}
116+
117+
/// Run `body` with `device` as the default device, scoped to the current task.
118+
static public func withDefaultDevice<R>(
119+
_ device: Device, _ body: () throws -> R
120+
) rethrows -> R {
121+
try $_tlDefaultDevice.withValue(device, operation: body)
122+
}
123+
124+
/// Run the async `body` with `device` as the default device, scoped to the current task.
125+
static public func withDefaultDevice<R>(
126+
_ device: Device, _ body: () async throws -> R
127+
) async rethrows -> R {
128+
try await $_tlDefaultDevice.withValue(device, operation: body)
129+
}
130+
131+
/// Return the current default stream.
85132
static func defaultStream() -> Stream {
86-
_lock.withLock {
87-
_defaultStream
88-
}
133+
_tlDefaultDevice.defaultStream
89134
}
90135

91-
/// Set the default device.
136+
/// Set the default device globally. Prefer the scoped version, ``withDefaultDevice(_:_:)``.
92137
///
93138
/// For example:
94139
///
@@ -99,12 +144,19 @@ public final class Device: @unchecked Sendable, Equatable {
99144
/// By default this is ``gpu``.
100145
///
101146
/// ### See Also
147+
/// - ``withDefaultDevice(_:_:)``
102148
/// - ``StreamOrDevice/default``
103-
static public func setDefault(device: Device) {
149+
@available(*, deprecated, message: "please use withDefaultDevice()")
150+
static public func setDefault(device: Device?) {
104151
_lock.withLock {
105-
mlx_set_default_device(device.ctx)
152+
if let device {
153+
// Sets the MLX core default device -- only used
154+
// by the deprecated init(). This isn't thread
155+
// safe or really usable across tasks/threads
156+
// but is kept for backward compatibility.
157+
mlx_set_default_device(device.ctx)
158+
}
106159
_defaultDevice = device
107-
_defaultStream = Stream(device)
108160
}
109161
}
110162

@@ -139,6 +191,7 @@ extension Device: CustomStringConvertible {
139191
/// - Parameters:
140192
/// - device: device to be used
141193
/// - fn: function to be executed
194+
@available(*, deprecated, message: "please use Device.withDefaultDevice()")
142195
public func using<R>(device: Device, fn: () throws -> R) rethrows -> R {
143-
try Stream.withNewDefaultStream(device: device, fn)
196+
try Device.withDefaultDevice(device, fn)
144197
}

Source/MLX/Stream.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public struct StreamOrDevice: Sendable, CustomStringConvertible, Equatable {
3535
/// This will be ``Device/gpu`` unless ``Device/setDefault(device:)``
3636
/// sets it otherwise.
3737
public static var `default`: StreamOrDevice {
38-
StreamOrDevice(Stream.defaultStream)
38+
StreamOrDevice(Stream.defaultStream ?? Device.defaultStream())
3939
}
4040

4141
public static func device(_ device: Device) -> StreamOrDevice {
@@ -83,10 +83,10 @@ public final class Stream: @unchecked Sendable, Equatable {
8383

8484
let ctx: mlx_stream
8585

86-
public static let gpu = Stream(.gpu)
87-
public static let cpu = Stream(.cpu)
86+
public static let gpu = Stream(mlx_default_gpu_stream_new())
87+
public static let cpu = Stream(mlx_default_cpu_stream_new())
8888

89-
@TaskLocal static var defaultStream = Stream()
89+
@TaskLocal static var defaultStream: Stream?
9090

9191
/// Set the ``StreamOrDevice/default`` scoped to a Task.
9292
public static func withNewDefaultStream<R>(device: Device? = nil, _ body: () throws -> R)

Tests/MLXTests/StreamTests.swift

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,65 @@ class StreamTests: XCTestCase {
3131
func testUsingDevice() {
3232
let defaultDevice = Device.defaultDevice()
3333

34-
using(device: .cpu) {
34+
Device.withDefaultDevice(.cpu) {
35+
// these _should_ be the same
36+
XCTAssertTrue(Device.defaultDevice().description.contains("cpu"))
3537
XCTAssertTrue(StreamOrDevice.default.description.contains("cpu"))
3638
}
3739
XCTAssertEqual(defaultDevice, Device.defaultDevice())
3840

39-
using(device: .gpu) {
41+
Device.withDefaultDevice(.gpu) {
42+
XCTAssertTrue(Device.defaultDevice().description.contains("gpu"))
4043
XCTAssertTrue(StreamOrDevice.default.description.contains("gpu"))
4144
}
4245
XCTAssertTrue(StreamOrDevice.default.description.contains("gpu"))
4346
}
47+
48+
func testSetUnsetDefaultDevice() {
49+
// Issue #237 -- setting and unsetting the default device in a loop
50+
// exhausts many resources
51+
for _ in 1 ..< 10000 {
52+
let defaultDevice = MLX.Device.defaultDevice()
53+
MLX.Device.setDefault(device: .cpu)
54+
defer {
55+
MLX.Device.setDefault(device: defaultDevice)
56+
}
57+
58+
let x = MLXArray(1)
59+
let _ = x * x
60+
}
61+
print("here")
62+
}
63+
64+
func testWithDefaultDevice() {
65+
// Issue #237 -- scoped variant
66+
for _ in 1 ..< 10000 {
67+
Device.withDefaultDevice(.cpu) {
68+
Device.withDefaultDevice(.gpu) {
69+
let x = MLXArray(1)
70+
let _ = x * x
71+
}
72+
}
73+
}
74+
print("here")
75+
}
76+
77+
func disabledTestCreateStream() {
78+
// see https://github.com/ml-explore/mlx/issues/2118
79+
for _ in 1 ..< 10000 {
80+
let _ = Stream(.cpu)
81+
}
82+
print("here")
83+
}
84+
85+
func disabledTestCreateStreamScoped() {
86+
// see https://github.com/ml-explore/mlx/issues/2118
87+
for _ in 1 ..< 10000 {
88+
Stream.withNewDefaultStream(device: .cpu) {
89+
let x = MLXArray(1)
90+
let _ = x * x
91+
}
92+
}
93+
}
94+
4495
}

0 commit comments

Comments
 (0)