Skip to content

Commit 96c0f69

Browse files
authored
add path information for parameter update errors -- great help in debugging (#252)
* add path information for parameter update errors -- great help in debugging
* conv2d missing divide by groups
1 parent f249ad7 commit 96c0f69

File tree

3 files changed

+84
-46
lines changed

3 files changed

+84
-46
lines changed

Source/MLXNN/Convolution.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ open class Conv1d: Module, UnaryLayer {
4747
) {
4848
let scale = sqrt(1 / Float(inputChannels * kernelSize))
4949

50+
precondition(inputChannels % groups == 0, "Input channels must be divisible by groups")
51+
5052
self.weight = MLXRandom.uniform(
5153
low: -scale, high: scale, [outputChannels, kernelSize, inputChannels / groups])
5254
self.bias = bias ? MLXArray.zeros([outputChannels]) : nil
@@ -111,9 +113,11 @@ open class Conv2d: Module, UnaryLayer {
111113
) {
112114
let scale = sqrt(1 / Float(inputChannels * kernelSize.first * kernelSize.second))
113115

116+
precondition(inputChannels % groups == 0, "Input channels must be divisible by groups")
117+
114118
self.weight = MLXRandom.uniform(
115119
low: -scale, high: scale,
116-
[outputChannels, kernelSize.first, kernelSize.second, inputChannels])
120+
[outputChannels, kernelSize.first, kernelSize.second, inputChannels / groups])
117121
self.bias = bias ? MLXArray.zeros([outputChannels]) : nil
118122
self.padding = padding.values
119123
self.dilation = dilation.values

Source/MLXNN/Module.swift

Lines changed: 72 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ open class Module {
387387
static public let noUnusedKeys = VerifyUpdate(rawValue: 1 << 0)
388388

389389
static public let allModelKeysSet = VerifyUpdate(rawValue: 1 << 1)
390+
static public let shapeMismatch = VerifyUpdate(rawValue: 1 << 2)
390391

391392
static public let all = VerifyUpdate(rawValue: -1)
392393
static public let none = VerifyUpdate([])
@@ -433,9 +434,16 @@ open class Module {
433434
/// - ``mapParameters(map:isLeaf:)``
434435
/// - ``update(modules:verify:)``
435436
@discardableResult
436-
open func update(parameters: ModuleParameters, verify: VerifyUpdate) throws -> Self {
437+
open func update(
438+
parameters: ModuleParameters, verify: VerifyUpdate, path: [String] = [],
439+
modulePath: [String] = []
440+
) throws -> Self {
437441

438-
func apply(key: String, _ item: ModuleItem, _ value: NestedItem<String, MLXArray>) throws {
442+
let modulePath = modulePath + [describeType(self)]
443+
444+
func apply(
445+
key: String, path: [String], _ item: ModuleItem, _ value: NestedItem<String, MLXArray>
446+
) throws {
439447
if case .none = value, !verify.contains(.allModelKeysSet) {
440448
return
441449
}
@@ -447,75 +455,86 @@ open class Module {
447455

448456
switch (item, value) {
449457
case (.value(.parameters(let p)), .value(let newArray)):
450-
if verify.contains(.all), p.shape != newArray.shape {
458+
if verify.contains(.shapeMismatch), p.shape != newArray.shape {
451459
throw UpdateError.mismatchedSize(
452-
key: key, expectedShape: p.shape, actualShape: newArray.shape)
460+
path: path, modules: modulePath, expectedShape: p.shape,
461+
actualShape: newArray.shape)
453462
}
454463
p._updateInternal(newArray)
455464

456465
case (.value(.parameters(let p)), .none):
457466
if Self.parameterIsValid(key) {
458-
throw UpdateError.keyNotFound(base: describeType(self), key: key)
467+
throw UpdateError.keyNotFound(path: path, modules: modulePath)
459468
} else {
460469
// ignore it -- this isn't a parameter that requires update
461470
}
462471

463472
case (.array(let array), .array(let values)):
464473
for (i, (arrayItem, valueItem)) in zip(array, values).enumerated() {
465-
try apply(key: "\(key).\(i)", arrayItem, valueItem)
474+
try apply(key: "\(key).\(i)", path: path + ["\(i)"], arrayItem, valueItem)
466475
}
467476
if verify.contains(.allModelKeysSet) {
468477
for i in values.count ..< array.count {
469-
try apply(key: "\(key).\(i)", array[i], .none)
478+
try apply(key: "\(key).\(i)", path: path + ["\(i)"], array[i], .none)
470479
}
471480
}
472481

473482
case (.array(let array), .none):
474483
for (i, arrayItem) in array.enumerated() {
475-
try apply(key: "\(key).\(i)", arrayItem, .none)
484+
try apply(key: "\(key).\(i)", path: path + ["\(i)"], arrayItem, .none)
476485
}
477486

478487
case (.dictionary(let dictionary), .dictionary(let values)):
479488
for (dictionaryKey, dictionaryItem) in dictionary {
489+
let newKey = "\(key).\(dictionaryKey)"
490+
let path = path + [dictionaryKey]
480491
if let valueItem = values[key] {
481-
try apply(key: "\(key).\(dictionaryKey)", dictionaryItem, valueItem)
492+
try apply(key: newKey, path: path, dictionaryItem, valueItem)
482493
} else if verify.contains(.allModelKeysSet) {
483-
try apply(key: "\(key).\(dictionaryKey)", dictionaryItem, .none)
494+
try apply(key: newKey, path: path, dictionaryItem, .none)
484495
}
485496
}
486497

487498
case (.dictionary(let dictionary), .none):
488499
for (dictionaryKey, dictionaryItem) in dictionary {
489-
try apply(key: "\(key).\(dictionaryKey)", dictionaryItem, .none)
500+
let newKey = "\(key).\(dictionaryKey)"
501+
let path = path + [dictionaryKey]
502+
try apply(key: newKey, path: path, dictionaryItem, .none)
490503
}
491504

492505
case (.value(.module(let module)), .dictionary(let values)):
493-
try module.update(parameters: NestedDictionary(values: values), verify: verify)
506+
try module.update(
507+
parameters: NestedDictionary(values: values), verify: verify, path: path,
508+
modulePath: modulePath)
494509

495510
case (.value(.module(let module)), .none):
496-
try module.update(parameters: NestedDictionary(), verify: verify)
511+
try module.update(
512+
parameters: NestedDictionary(), verify: verify, path: path,
513+
modulePath: modulePath)
497514

498515
case (.none, .none), (.value(.none), .none), (.value(.other(_)), .none):
499516
break
500517

501518
default:
502-
fatalError("Unable to set \(key) on \(self): \(item) not compatible with \(value)")
519+
fatalError(
520+
"Unable to set \(path.joined(separator: ".")) on \(modulePath.joined(separator: ".")): \(item) not compatible with \(value.mapValues { $0.shape.description })"
521+
)
503522
}
504523
}
505524

506525
var processed = Set(parameters.keys)
507526
for (key, item) in items() {
508527
if let value = parameters[key] {
509528
processed.remove(key)
510-
try apply(key: key, item, value)
529+
try apply(key: key, path: path + [key], item, value)
511530
} else if verify.contains(.allModelKeysSet) {
512-
try apply(key: key, item, .none)
531+
try apply(key: key, path: path + [key], item, .none)
513532
}
514533
}
515534

516535
if verify.contains(.noUnusedKeys) && !processed.isEmpty {
517536
throw UpdateError.unhandledKeys(
518-
base: describeType(self), keys: processed.sorted())
537+
path: path, modules: modulePath, keys: processed.sorted())
519538
}
520539

521540
return self
@@ -594,9 +613,16 @@ open class Module {
594613
/// - ``leafModules()``
595614
/// - ``QuantizedLinear/quantize(model:groupSize:bits:predicate:)``
596615
@discardableResult
597-
open func update(modules: ModuleChildren, verify: VerifyUpdate) throws -> Self {
616+
open func update(
617+
modules: ModuleChildren, verify: VerifyUpdate, path: [String] = [],
618+
modulePath: [String] = []
619+
) throws -> Self {
598620

599-
func apply(key: String, _ item: ModuleItem, _ value: NestedItem<String, Module>) throws {
621+
let modulePath = modulePath + [describeType(self)]
622+
623+
func apply(
624+
key: String, path: [String], _ item: ModuleItem, _ value: NestedItem<String, Module>
625+
) throws {
600626
// item: single item from `items()`
601627
// value: single item with matching structure from `children()`
602628
//
@@ -605,7 +631,7 @@ open class Module {
605631
switch (item, value) {
606632
case (.value(.parameters), .value):
607633
fatalError(
608-
"Unable to set \(key) on \(self): parameters (MLXArray) cannot be updated with a Module"
634+
"Unable to set \(path.joined(separator: ".")) on \(modulePath.joined(separator: ".")): parameters (MLXArray) cannot be updated with a Module"
609635
)
610636

611637
case (.array(let items), .array(let values)):
@@ -634,17 +660,17 @@ open class Module {
634660
default:
635661
// otherwise we don't know how to update it
636662
throw UpdateError.unableToCollectModulesFromContainer(
637-
base: describeType(self), key: key)
663+
path: path, modules: modulePath)
638664
}
639665
} else {
640666
// past the end of items
641667
throw UpdateError.unableToCollectModulesFromContainer(
642-
base: describeType(self), key: key)
668+
path: path, modules: modulePath)
643669
}
644670

645671
default:
646672
throw UpdateError.unableToCollectModulesFromContainer(
647-
base: describeType(self), key: key)
673+
path: path, modules: modulePath)
648674
}
649675
}
650676

@@ -678,7 +704,7 @@ open class Module {
678704
newModules[item.key] = module
679705
default:
680706
throw UpdateError.unableToCollectModulesFromContainer(
681-
base: describeType(self), key: key)
707+
path: path, modules: modulePath)
682708
}
683709
}
684710

@@ -697,21 +723,23 @@ open class Module {
697723
try module.update(modules: NestedDictionary(values: values), verify: verify)
698724

699725
default:
700-
fatalError("Unable to set \(key) on \(self): \(item) not compatible with \(value)")
726+
fatalError(
727+
"Unable to set \(path.joined(separator: ".")) on \(modulePath.joined(separator: ".")): \(item) not compatible with \(value)"
728+
)
701729
}
702730
}
703731

704732
var processed = Set(modules.keys)
705733
for (key, item) in items() {
706734
if let value = modules[key] {
707735
processed.remove(key)
708-
try apply(key: key, item, value)
736+
try apply(key: key, path: path + [key], item, value)
709737
}
710738
}
711739

712740
if verify.contains(.noUnusedKeys) && !processed.isEmpty {
713741
throw UpdateError.unhandledKeys(
714-
base: describeType(self), keys: processed.sorted())
742+
path: path, modules: modulePath, keys: processed.sorted())
715743
}
716744

717745
// rebuild the caches because the modules may have changed
@@ -806,7 +834,7 @@ open class Module {
806834
let localKeys = Set(localKeys)
807835
for key in keys {
808836
if !localKeys.contains(key) {
809-
throw UpdateError.keyNotFound(base: describeType(self), key: key)
837+
throw UpdateError.keyNotFound(path: [key], modules: [describeType(self)])
810838
}
811839
}
812840
}
@@ -1537,36 +1565,40 @@ private protocol TypeErasedSetterProvider {
15371565
}
15381566

15391567
enum UpdateError: Error {
1540-
case unableToCollectModulesFromContainer(base: String, key: String)
1568+
case unableToCollectModulesFromContainer(path: [String], modules: [String])
15411569
case mismatchedContainers(base: String, key: String)
1542-
case mismatchedSize(key: String, expectedShape: [Int], actualShape: [Int])
1543-
case keyNotFound(base: String, key: String)
1570+
case mismatchedSize(path: [String], modules: [String], expectedShape: [Int], actualShape: [Int])
1571+
case keyNotFound(path: [String], modules: [String])
15441572
case needModuleInfo(String)
15451573
case unableToSet(String)
15461574
case unableToCast(String)
1547-
case unhandledKeys(base: String, keys: [String])
1575+
case unhandledKeys(path: [String], modules: [String], keys: [String])
15481576
}
15491577

15501578
extension UpdateError: LocalizedError {
15511579
var errorDescription: String? {
15521580
switch self {
1553-
case .unableToCollectModulesFromContainer(let base, let key):
1554-
return "Unable to collect modules from container: \(base) \(key)"
1581+
case .unableToCollectModulesFromContainer(let path, let modules):
1582+
return
1583+
"Unable to collect modules from container: \(path.joined(separator: ".")) in \(modules.joined(separator: "."))"
15551584
case .mismatchedContainers(let base, let key):
15561585
return "Mismatched containers: \(base) \(key)"
1557-
case let .mismatchedSize(key: key, expectedShape: expectedShape, actualShape: actualShape):
1586+
case let .mismatchedSize(
1587+
path, modules, expectedShape: expectedShape, actualShape: actualShape):
15581588
return
1559-
"Mismatched parameter \(key) shape. Actual \(actualShape), expected \(expectedShape)"
1560-
case .keyNotFound(let base, let key):
1561-
return "Key \(key) not found in \(base)"
1589+
"Mismatched parameter \(path.joined(separator: ".")) in \(modules.joined(separator: ".")) shape. Actual \(actualShape), expected \(expectedShape)"
1590+
case .keyNotFound(let path, let modules):
1591+
return
1592+
"Key \(path.joined(separator: ".")) not found in \(modules.joined(separator: "."))"
15621593
case .needModuleInfo(let string):
15631594
return string
15641595
case .unableToSet(let string):
15651596
return string
15661597
case .unableToCast:
15671598
return "Unable to cast value"
1568-
case .unhandledKeys(let base, let keys):
1569-
return "Unhandled keys \(keys) in \(base)"
1599+
case .unhandledKeys(let path, let modules, let keys):
1600+
return
1601+
"Unhandled keys \(keys) in \(path.joined(separator: ".")) in \(modules.joined(separator: "."))"
15701602
}
15711603
}
15721604
}

Tests/MLXTests/ModuleTests.swift

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -555,12 +555,14 @@ class ModuleTests: XCTestCase {
555555
verify: .all)
556556
) { error in
557557
guard let error = error as? UpdateError,
558-
case let .keyNotFound(base: base, key: key) = error
558+
case let .keyNotFound(path, modules) = error
559559
else {
560560
XCTFail("Expected to fail with UpdateError.keyNotFound, but got: \(error)")
561561
return
562562
}
563-
XCTAssertEqual(key, "bias")
563+
// should be a.bias or b.bias (random order as it is a dict)
564+
XCTAssertEqual(path.last, "bias")
565+
XCTAssertEqual(modules, ["M", "Linear"])
564566
}
565567
}
566568

@@ -585,18 +587,18 @@ class ModuleTests: XCTestCase {
585587
) { error in
586588
guard let error = error as? UpdateError,
587589
case let .mismatchedSize(
588-
key: key, expectedShape: expectedShape, actualShape: actualShape) =
590+
path, modules, expectedShape: expectedShape, actualShape: actualShape) =
589591
error
590592
else {
591593
XCTFail("Expected to fail with UpdateError.mismatchedSize, but got: \(error)")
592594
return
593595
}
594596
XCTAssertEqual(expectedShape, [2, 1])
595597
XCTAssertEqual(actualShape, [1, 2])
596-
XCTAssertEqual(key, "weight")
598+
XCTAssertEqual(path, ["weight"])
597599
XCTAssertEqual(
598600
error.errorDescription,
599-
"Mismatched parameter weight shape. Actual [1, 2], expected [2, 1]")
601+
"Mismatched parameter weight in Linear shape. Actual [1, 2], expected [2, 1]")
600602
}
601603
}
602604

0 commit comments

Comments
 (0)