Skip to content

Commit 737f4fa

Browse files
committed
move files around and update sigmoid/softmax to allow for correct multiclass model predictions
1 parent 1fd846a commit 737f4fa

22 files changed

+148
-113
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,6 @@ fastlane/test_output
9090
iOSInjectionProject/
9191

9292
.DS_Store
93+
94+
**/*.mlpackage
95+
**/*.mlmodel

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,11 @@ If you've previously installed the Roboflow SDK via Cocoapods, you'll need to up
178178
The SDK includes a comprehensive test suite that validates model loading and inference functionality. To run the tests:
179179

180180
```bash
181+
# for swift only tests
181182
swift test
183+
184+
# for iOS simulator tests
185+
xcodebuild test -scheme RoboflowTests -destination 'platform=iOS Simulator,arch=arm64,OS=18.5,name=iPhone 16'
182186
```
183187

184188
The test suite includes:

Sources/Roboflow/Classes/Roboflow.swift

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ public class RoboflowMobile: NSObject {
5858
let colors = modelInfo["colors"] as? [String: String],
5959
let classes = modelInfo["classes"] as? [String],
6060
let name = modelInfo["name"] as? String,
61-
let modelType = modelInfo["modelType"] as? String {
61+
let modelType = modelInfo["modelType"] as? String,
62+
let environment = modelInfo["environment"] as? [String: Any] {
6263

6364
getConfigDataBackground(modelName: model, modelVersion: modelVersion, apiKey: apiKey, deviceID: deviceID)
6465

@@ -69,7 +70,7 @@ public class RoboflowMobile: NSObject {
6970
in: .userDomainMask,
7071
appropriateFor: nil,
7172
create: false)
72-
_ = modelObject.loadMLModel(modelPath: documentsURL.appendingPathComponent(modelURL), colors: colors, classes: classes)
73+
_ = modelObject.loadMLModel(modelPath: documentsURL.appendingPathComponent(modelURL), colors: colors, classes: classes, environment: environment)
7374

7475
completion(modelObject, nil, name, modelType)
7576
} catch {
@@ -78,15 +79,14 @@ public class RoboflowMobile: NSObject {
7879
} else if retries > 0 {
7980
clearModelCache(modelName: model, modelVersion: modelVersion)
8081
retries -= 1
81-
getModelData(modelName: model, modelVersion: modelVersion, apiKey: apiKey, deviceID: deviceID) { [self] fetchedModel, error, modelName, modelType, colors, classes in
82+
getModelData(modelName: model, modelVersion: modelVersion, apiKey: apiKey, deviceID: deviceID) { [self] fetchedModel, error, modelName, modelType, colors, classes, environment in
8283
if let err = error {
8384
completion(nil, err, "", "")
8485
} else if let fetchedModel = fetchedModel {
8586
let modelObject = getModelClass(modelType: modelType)
86-
_ = modelObject.loadMLModel(modelPath: fetchedModel, colors: colors ?? [:], classes: classes ?? [])
87+
_ = modelObject.loadMLModel(modelPath: fetchedModel, colors: colors ?? [:], classes: classes ?? [], environment: environment ?? [:])
8788
completion(modelObject, nil, modelName, modelType)
8889
} else {
89-
print("No Model Found. Trying Again.")
9090
clearAndRetryLoadingModel(model, modelVersion, completion)
9191
}
9292
}
@@ -138,7 +138,6 @@ public class RoboflowMobile: NSObject {
138138
completion(dict, nil)
139139

140140
} catch {
141-
print(error.localizedDescription)
142141
completion(nil, error.localizedDescription)
143142
}
144143
}).resume()
@@ -152,10 +151,10 @@ public class RoboflowMobile: NSObject {
152151

153152

154153
//Get the model metadata from the Roboflow API
155-
private func getModelData(modelName: String, modelVersion: Int, apiKey: String, deviceID: String, completion: @escaping (URL?, Error?, String, String, [String: String]?, [String]?)->()) {
154+
private func getModelData(modelName: String, modelVersion: Int, apiKey: String, deviceID: String, completion: @escaping (URL?, Error?, String, String, [String: String]?, [String]?, [String: Any]?)->()) {
156155
getConfigData(modelName: modelName, modelVersion: modelVersion, apiKey: apiKey, deviceID: deviceID) { data, error in
157156
if let error = error {
158-
completion(nil, error, "", "", nil, nil)
157+
completion(nil, error, "", "", nil, nil, nil)
159158
return
160159
}
161160

@@ -165,7 +164,7 @@ public class RoboflowMobile: NSObject {
165164
let modelType = coreMLDict["modelType"] as? String,
166165
let modelURLString = coreMLDict["model"] as? String,
167166
let modelURL = URL(string: modelURLString) else {
168-
completion(nil, error, "", "", nil, nil)
167+
completion(nil, error, "", "", nil, nil, nil)
169168
return
170169
}
171170

@@ -188,28 +187,29 @@ public class RoboflowMobile: NSObject {
188187
//Download the model from the link in the API response
189188
self.downloadModelFile(modelName: "\(modelName)-\(modelVersion).mlmodel", modelVersion: modelVersion, modelURL: modelURL) { fetchedModel, error in
190189
if let error = error {
191-
completion(nil, error, "", "", nil, nil)
190+
completion(nil, error, "", "", nil, nil, nil)
192191
return
193192
}
194193

195194
if let fetchedModel = fetchedModel {
196-
_ = self.cacheModelInfo(modelName: modelName, modelVersion: modelVersion, colors: colors ?? [:], classes: classes ?? [], name: name, modelType: modelType, compiledModelURL: fetchedModel)
197-
completion(fetchedModel, nil, name, modelType, colors, classes)
195+
_ = self.cacheModelInfo(modelName: modelName, modelVersion: modelVersion, colors: colors ?? [:], classes: classes ?? [], name: name, modelType: modelType, compiledModelURL: fetchedModel, environment: environmentDict ?? [:])
196+
completion(fetchedModel, nil, name, modelType, colors, classes, environmentDict)
198197
} else {
199-
completion(nil, error, "", "", nil, nil)
198+
completion(nil, error, "", "", nil, nil, nil)
200199
}
201200
}
202201
}
203202
}
204203

205204

206-
private func cacheModelInfo(modelName: String, modelVersion: Int, colors: [String: String], classes: [String], name: String, modelType: String, compiledModelURL: URL) -> [String: Any]? {
205+
private func cacheModelInfo(modelName: String, modelVersion: Int, colors: [String: String], classes: [String], name: String, modelType: String, compiledModelURL: URL, environment: [String: Any]) -> [String: Any]? {
207206
let modelInfo: [String : Any] = [
208207
"colors": colors,
209208
"classes": classes,
210209
"name": name,
211210
"modelType": modelType,
212-
"compiledModelURL": compiledModelURL.lastPathComponent
211+
"compiledModelURL": compiledModelURL.lastPathComponent,
212+
"environment": environment
213213
]
214214

215215
do {
@@ -227,8 +227,6 @@ public class RoboflowMobile: NSObject {
227227
if let modelInfoData = UserDefaults.standard.data(forKey: "\(modelName)-\(modelVersion)") {
228228
let decodedData = try NSKeyedUnarchiver.unarchivedObject(ofClasses: [NSDictionary.self, NSString.self, NSArray.self], from: modelInfoData) as? [String: Any]
229229
return decodedData
230-
} else {
231-
print("Error: Could not find data for key \(modelName)-\(modelVersion)")
232230
}
233231
} catch {
234232
print("Error unarchiving data: \(error.localizedDescription)")
@@ -257,13 +255,9 @@ public class RoboflowMobile: NSObject {
257255
finalModelURL = try self.unzipModelFile(zipURL: finalModelURL)
258256
}
259257

260-
print("Compiling model at: \(finalModelURL)")
261-
262258
//Compile the downloaded model
263259
let compiledModelURL = try MLModel.compileModel(at: finalModelURL)
264260

265-
print("Model compiled to: \(compiledModelURL)")
266-
267261
// Ensure Documents directory exists
268262
let documentsURL = try FileManager.default.url(for: .documentDirectory,
269263
in: .userDomainMask,
@@ -272,13 +266,9 @@ public class RoboflowMobile: NSObject {
272266

273267
let savedURL = documentsURL.appendingPathComponent("\(modelName)-\(modelVersion).mlmodelc")
274268

275-
print("Attempting to move from: \(compiledModelURL)")
276-
print("To: \(savedURL)")
277-
278269
// Check if the compiled model exists
279270
guard FileManager.default.fileExists(atPath: compiledModelURL.path) else {
280271
let error = NSError(domain: "ModelCompilationError", code: 1, userInfo: [NSLocalizedDescriptionKey: "Compiled model does not exist at: \(compiledModelURL.path)"])
281-
print("Error: \(error.localizedDescription)")
282272
completion(nil, error)
283273
return
284274
}
@@ -291,15 +281,11 @@ public class RoboflowMobile: NSObject {
291281
// Move the compiled model
292282
do {
293283
try FileManager.default.moveItem(at: compiledModelURL, to: savedURL)
294-
print("Successfully moved model to: \(savedURL)")
295284
} catch {
296-
print("Move failed with error: \(error.localizedDescription)")
297285
// If move fails, try copying instead
298286
do {
299287
try FileManager.default.copyItem(at: compiledModelURL, to: savedURL)
300-
print("Successfully copied model to: \(savedURL)")
301288
} catch {
302-
print("Copy also failed: \(error.localizedDescription)")
303289
completion(nil, error)
304290
return
305291
}

Sources/Roboflow/Classes/RFClassificationModel.swift renamed to Sources/Roboflow/Classes/core/RFClassificationModel.swift

Lines changed: 13 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -17,60 +17,19 @@ public class RFClassificationModel: RFModel {
1717
}
1818

1919
//Default model configuration parameters
20-
var threshold: Double = 0.5
21-
var classes: [String] = []
22-
23-
//Configure the parameters for the model
24-
public override func configure(threshold: Double, overlap: Double, maxObjects: Float, processingMode: ProcessingMode = .balanced, maxNumberPoints: Int = 500) {
25-
self.threshold = threshold
26-
}
20+
var multiclass: Bool = false
2721

2822
//Load the retrieved CoreML model into an already created RFClassificationModel instance
29-
override func loadMLModel(modelPath: URL, colors: [String: String], classes: [String]) -> Error? {
30-
self.classes = classes
31-
do {
32-
if #available(macOS 10.14, *) {
33-
let config = MLModelConfiguration()
34-
if #available(macOS 10.15, *) {
35-
mlModel = try MLModel(contentsOf: modelPath, configuration: config)
36-
} else {
37-
// Fallback on earlier versions
38-
return UnsupportedOSError()
39-
}
40-
visionModel = try VNCoreMLModel(for: mlModel)
41-
let request = VNCoreMLRequest(model: visionModel)
42-
request.imageCropAndScaleOption = .scaleFill
43-
coreMLRequest = request
44-
} else {
45-
// Fallback on earlier versions
46-
return UnsupportedOSError()
47-
}
48-
49-
} catch {
50-
return error
23+
override func loadMLModel(modelPath: URL, colors: [String: String], classes: [String], environment: [String: Any]) -> Error? {
24+
let _ = super.loadMLModel(modelPath: modelPath, colors: colors, classes: classes, environment: environment)
25+
if let _ = environment["MULTICLASS"] {
26+
self.multiclass = true
5127
}
52-
return nil
53-
}
54-
55-
//Load a local model file (for manually placed models like ResNet)
56-
public func loadLocalModel(modelPath: URL) -> Error? {
5728
do {
5829
if #available(macOS 10.14, *) {
5930
let config = MLModelConfiguration()
6031
if #available(macOS 10.15, *) {
61-
var modelURL = modelPath
62-
63-
// If the model is .mlpackage, compile it first
64-
if modelPath.pathExtension == "mlpackage" {
65-
do {
66-
let compiledModelURL = try MLModel.compileModel(at: modelPath)
67-
modelURL = compiledModelURL
68-
} catch {
69-
return error
70-
}
71-
}
72-
73-
mlModel = try MLModel(contentsOf: modelURL, configuration: config)
32+
mlModel = try MLModel(contentsOf: modelPath, configuration: config)
7433
} else {
7534
// Fallback on earlier versions
7635
return UnsupportedOSError()
@@ -141,20 +100,12 @@ public class RFClassificationModel: RFModel {
141100
let rawValues = multiArray.dataPointer.bindMemory(to: Float.self, capacity: multiArray.count)
142101

143102
// Check if values are logits (outside 0-1 range) and need softmax
144-
var needsSoftmax = false
145-
for i in 0..<multiArray.count {
146-
if rawValues[i] < 0.0 || rawValues[i] > 1.0 {
147-
needsSoftmax = true
148-
break
149-
}
150-
}
151-
152103
let probabilities: [Float]
153-
if needsSoftmax {
104+
if !multiclass {
154105
// Apply softmax to convert logits to probabilities
155106
probabilities = applySoftmax(logits: rawValues, count: multiArray.count)
156107
} else {
157-
// Values are already probabilities
108+
// Values are already probabilities (sigmoid applied in model)
158109
probabilities = Array(UnsafeBufferPointer(start: rawValues, count: multiArray.count))
159110
}
160111

@@ -183,7 +134,11 @@ public class RFClassificationModel: RFModel {
183134
return pred1.confidence > pred2.confidence
184135
}
185136

186-
completion(predictions, nil)
137+
if multiclass {
138+
completion(predictions, nil)
139+
} else {
140+
completion(predictions.isEmpty ? [] : [predictions[0]], nil)
141+
}
187142
} catch let error {
188143
completion(nil, error)
189144
}

Sources/Roboflow/Classes/RFInstanceSegmentationModel.swift renamed to Sources/Roboflow/Classes/core/RFInstanceSegmentationModel.swift

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,11 @@ import Accelerate
1414

1515
//Creates an instance of an ML model that's hosted on Roboflow
1616
public class RFInstanceSegmentationModel: RFObjectDetectionModel {
17-
var classes = [String]()
18-
var maskProcessingMode: ProcessingMode = .balanced
19-
var maskMaxNumberPoints: Int = 500
20-
21-
public override func configure(threshold: Double, overlap: Double, maxObjects: Float, processingMode: ProcessingMode = .balanced, maxNumberPoints: Int = 500) {
22-
super.configure(threshold: threshold, overlap: overlap, maxObjects: maxObjects, processingMode: processingMode)
23-
maskProcessingMode = processingMode
24-
maskMaxNumberPoints = maxNumberPoints
25-
}
2617

2718

2819
//Load the retrieved CoreML model into an already created RFObjectDetectionModel instance
29-
override func loadMLModel(modelPath: URL, colors: [String: String], classes: [String]) -> Error? {
30-
self.colors = colors
31-
self.classes = classes
20+
override func loadMLModel(modelPath: URL, colors: [String: String], classes: [String], environment: [String: Any]) -> Error? {
21+
let _ = super.loadMLModel(modelPath: modelPath, colors: colors, classes: classes, environment: environment)
3222
do {
3323

3424
if #available(iOS 16.0, macOS 13.0, *) {

Sources/Roboflow/Classes/RFModel.swift renamed to Sources/Roboflow/Classes/core/RFModel.swift

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,31 @@ public class RFModel: NSObject {
2626
var mlModel: MLModel!
2727
var visionModel: VNCoreMLModel!
2828
var coreMLRequest: VNCoreMLRequest!
29-
29+
var environment: [String: Any]!
30+
var modelPath: URL!
31+
var colors: [String: String]!
32+
var classes: [String]!
33+
var threshold: Double = 0.5
34+
var overlap: Double = 0.4
35+
var maxObjects: Float = 20
36+
var maskProcessingMode: ProcessingMode = .balanced
37+
var maskMaxNumberPoints: Int = 500
38+
3039
//Configure the parameters for the model
31-
public func configure(threshold: Double, overlap: Double, maxObjects: Float, processingMode: ProcessingMode = .balanced, maxNumberPoints: Int = 500) {}
40+
public func configure(threshold: Double = 0.5, overlap: Double = 0.5, maxObjects: Float = 20, processingMode: ProcessingMode = .balanced, maxNumberPoints: Int = 500) {
41+
self.threshold = threshold
42+
self.overlap = overlap
43+
self.maxObjects = maxObjects
44+
self.maskProcessingMode = processingMode
45+
self.maskMaxNumberPoints = maxNumberPoints
46+
}
3247

3348
//Load the retrieved CoreML model into an already created RFObjectDetectionModel instance
34-
func loadMLModel(modelPath: URL, colors: [String: String], classes: [String]) -> Error? {
49+
func loadMLModel(modelPath: URL, colors: [String: String], classes: [String], environment: [String: Any]) -> Error? {
50+
self.environment = environment
51+
self.modelPath = modelPath
52+
self.colors = colors
53+
self.classes = classes
3554
return nil
3655
}
3756

Sources/Roboflow/Classes/RFObjectDetectionModel.swift renamed to Sources/Roboflow/Classes/core/RFObjectDetectionModel.swift

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,12 @@ public class RFObjectDetectionModel: RFModel {
1515
super.init()
1616
}
1717

18-
//Default model configuration parameters
19-
var threshold: Double = 0.5
20-
var overlap: Double = 0.5
21-
var maxObjects: Float = 20.0
22-
var colors: [String: String]!
2318
//Stores the retreived ML model
2419
var thresholdProvider = ThresholdProvider()
2520

2621
//Configure the parameters for the model
27-
public override func configure(threshold: Double, overlap: Double, maxObjects: Float, processingMode: ProcessingMode = .balanced, maxNumberPoints: Int = 500) {
28-
self.threshold = threshold
29-
self.overlap = overlap
30-
self.maxObjects = maxObjects
22+
public override func configure(threshold: Double = 0.5, overlap: Double = 0.5, maxObjects: Float = 20, processingMode: ProcessingMode = .balanced, maxNumberPoints: Int = 500) {
23+
super.configure(threshold: threshold, overlap: overlap, maxObjects: maxObjects, processingMode: processingMode, maxNumberPoints: maxNumberPoints)
3124
thresholdProvider.values = ["iouThreshold": MLFeatureValue(double: self.overlap),
3225
"confidenceThreshold": MLFeatureValue(double: self.threshold)]
3326
if visionModel != nil {
@@ -40,8 +33,8 @@ public class RFObjectDetectionModel: RFModel {
4033
}
4134

4235
//Load the retrieved CoreML model into an already created RFObjectDetectionModel instance
43-
override func loadMLModel(modelPath: URL, colors: [String: String], classes: [String]) -> Error? {
44-
self.colors = colors
36+
override func loadMLModel(modelPath: URL, colors: [String: String], classes: [String], environment: [String: Any]) -> Error? {
37+
super.loadMLModel(modelPath: modelPath, colors: colors, classes: classes, environment: environment)
4538
do {
4639
if #available(macOS 10.14, *) {
4740
let config = MLModelConfiguration()

0 commit comments

Comments
 (0)