Skip to content

Commit fb2b5c8

Browse files
Merge pull request #13 from roboflow/bugfix-incorrect-seg-model-output-ordering
Update RFInstanceSegmentationModel.swift
2 parents 6c5ba16 + d2f4657 commit fb2b5c8

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

Sources/Roboflow/Classes/RFInstanceSegmentationModel.swift

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,19 +67,20 @@ public class RFInstanceSegmentationModel: RFObjectDetectionModel {
6767
do {
6868
try handler.perform([coreMLRequest])
6969
guard let detectResults = coreMLRequest.results else { return }
70+
71+
let castDetectResults0 = (detectResults[0] as! VNCoreMLFeatureValueObservation).featureValue.multiArrayValue!
72+
let castDetectResults1 = (detectResults[1] as! VNCoreMLFeatureValueObservation).featureValue.multiArrayValue!
7073

71-
let predictions = detectResults[1] as! VNCoreMLFeatureValueObservation
72-
let protos = detectResults[0] as! VNCoreMLFeatureValueObservation
73-
74-
let pred = predictions.featureValue.multiArrayValue!
75-
let proto = protos.featureValue.multiArrayValue!
74+
let pred = castDetectResults0.shape.count == 3 ? castDetectResults0 : castDetectResults1
75+
let proto = castDetectResults1.shape.count == 4 ? castDetectResults1 : castDetectResults0
7676

7777
let numMasks = 32
7878
let numCls = self.colors.count
7979

8080
// --- flatten MLMultiArray to Swift [Float] for speed
8181
let p = pred.dataPointer.bindMemory(to: Float.self,
8282
capacity: pred.count)
83+
8384
let preds = UnsafeBufferPointer(start: p, count: pred.count)
8485
let protoShape = (c:Int(truncating: proto.shape[1]),
8586
h:Int(truncating: proto.shape[2]),

0 commit comments

Comments
 (0)