@@ -28,9 +28,19 @@ public enum DeviceType: String, Hashable, Sendable {
2828public final class Device : @unchecked Sendable , Equatable {
2929
3030 let ctx : mlx_device
31+ let defaultStream : Stream
3132
3233 init ( _ ctx: mlx_device ) {
3334 self . ctx = ctx
35+
36+ var deviceType = MLX_GPU
37+ mlx_device_get_type ( & deviceType, ctx)
38+ self . defaultStream =
39+ switch deviceType {
40+ case MLX_CPU: . cpu
41+ case MLX_GPU: . gpu
42+ default : . gpu
43+ }
3444 }
3545
3646 public init ( _ deviceType: DeviceType , index: Int32 = 0 ) {
@@ -42,19 +52,32 @@ public final class Device: @unchecked Sendable, Equatable {
4252 cDeviceType = MLX_GPU
4353 }
4454 self . ctx = mlx_device_new_type ( cDeviceType, index)
55+ self . defaultStream =
56+ switch deviceType {
57+ case . cpu: . cpu
58+ case . gpu: . gpu
59+ }
4560 }
4661
47- public init ( ) {
62+ @available ( * , deprecated, message: " please use defaultDevice() " )
63+ public convenience init ( ) {
4864 var ctx = mlx_device_new ( )
4965 mlx_get_default_device ( & ctx)
50- self . ctx = ctx
66+ self . init ( ctx)
5167 }
5268
5369 deinit {
5470 mlx_device_free ( ctx)
5571 }
5672
73+ /// static CPU device
74+ ///
75+ /// See ``withDefaultDevice(_:_:)``
5776 static public let cpu : Device = Device ( . cpu)
77+
78+ /// static GPU device
79+ ///
80+ /// See ``withDefaultDevice(_:_:)``
5881 static public let gpu : Device = Device ( . gpu)
5982
6083 public var deviceType : DeviceType ? {
@@ -67,28 +90,50 @@ public final class Device: @unchecked Sendable, Equatable {
6790 }
6891 }
6992
93+ // support for global default device
7094 static let _lock = NSLock ( )
7195 #if swift(>=5.10)
72- nonisolated ( unsafe) static var _defaultDevice = gpu
73- nonisolated ( unsafe) static var _defaultStream = Stream ( gpu)
96+ nonisolated ( unsafe) static var _defaultDevice : Device ?
7497 #else
75- static var _defaultDevice = gpu
76- static var _defaultStream = Stream ( gpu)
98+ static var _defaultDevice : Device ?
7799 #endif
78100
79- static public func defaultDevice( ) -> Device {
101+ @TaskLocal static var _tlDefaultDevice = _resolveGlobalDefaultDevice ( )
102+
103+ private static func _resolveGlobalDefaultDevice( ) -> Device {
80104 _lock. withLock {
81- _defaultDevice
105+ _defaultDevice ?? . gpu
82106 }
83107 }
84108
109+ /// Return the current default device.
110+ ///
111+ /// This is used by ``StreamOrDevice/default`` -- the default stream parameter
112+ /// to most functions.
113+ static public func defaultDevice( ) -> Device {
114+ _tlDefaultDevice
115+ }
116+
117+ /// Use a device scoped to a task.
118+ static public func withDefaultDevice< R> (
119+ _ device: Device , _ body: ( ) throws -> R
120+ ) rethrows -> R {
121+ try $_tlDefaultDevice. withValue ( device, operation: body)
122+ }
123+
124+ /// Use a device scoped to a task.
125+ static public func withDefaultDevice< R> (
126+ _ device: Device , _ body: ( ) async throws -> R
127+ ) async rethrows -> R {
128+ try await $_tlDefaultDevice. withValue ( device, operation: body)
129+ }
130+
131+ /// Return the current default stream.
85132 static func defaultStream( ) -> Stream {
86- _lock. withLock {
87- _defaultStream
88- }
133+ _tlDefaultDevice. defaultStream
89134 }
90135
91- /// Set the default device.
136+ /// Set the default device globally. Prefer the scoped version, ``withDefaultDevice(_:_:)`` .
92137 ///
93138 /// For example:
94139 ///
@@ -99,12 +144,19 @@ public final class Device: @unchecked Sendable, Equatable {
99144 /// By default this is ``gpu``.
100145 ///
101146 /// ### See Also
147+ /// - ``withDefaultDevice(_:_:)``
102148 /// - ``StreamOrDevice/default``
103- static public func setDefault( device: Device ) {
149+ @available ( * , deprecated, message: " please use withDefaultDevice() " )
150+ static public func setDefault( device: Device ? ) {
104151 _lock. withLock {
105- mlx_set_default_device ( device. ctx)
152+ if let device {
153+ // sets the mlx core default device -- only used
154+ // by the deprecated init(). this isn't thread
155+ // safe or really usable across tasks/threads
156+ // but is kept for backward compatibility
157+ mlx_set_default_device ( device. ctx)
158+ }
106159 _defaultDevice = device
107- _defaultStream = Stream ( device)
108160 }
109161 }
110162
@@ -139,6 +191,7 @@ extension Device: CustomStringConvertible {
139191/// - Parameters:
140192/// - device: device to be used
141193/// - fn: function to be executed
194+ @available ( * , deprecated, message: " please use Device.withDefaultDevice() " )
142195public func using< R> ( device: Device , fn: ( ) throws -> R ) rethrows -> R {
143- try Stream . withNewDefaultStream ( device : device, fn)
196+ try Device . withDefaultDevice ( device, fn)
144197}
0 commit comments