@@ -387,6 +387,7 @@ open class Module {
387387 static public let noUnusedKeys = VerifyUpdate ( rawValue: 1 << 0 )
388388
389389 static public let allModelKeysSet = VerifyUpdate ( rawValue: 1 << 1 )
390+ static public let shapeMismatch = VerifyUpdate ( rawValue: 1 << 2 )
390391
391392 static public let all = VerifyUpdate ( rawValue: - 1 )
392393 static public let none = VerifyUpdate ( [ ] )
@@ -433,9 +434,16 @@ open class Module {
433434 /// - ``mapParameters(map:isLeaf:)``
434435 /// - ``update(modules:verify:)``
435436 @discardableResult
436- open func update( parameters: ModuleParameters , verify: VerifyUpdate ) throws -> Self {
437+ open func update(
438+ parameters: ModuleParameters , verify: VerifyUpdate , path: [ String ] = [ ] ,
439+ modulePath: [ String ] = [ ]
440+ ) throws -> Self {
437441
438- func apply( key: String , _ item: ModuleItem , _ value: NestedItem < String , MLXArray > ) throws {
442+ let modulePath = modulePath + [ describeType ( self ) ]
443+
444+ func apply(
445+ key: String , path: [ String ] , _ item: ModuleItem , _ value: NestedItem < String , MLXArray >
446+ ) throws {
439447 if case . none = value, !verify. contains ( . allModelKeysSet) {
440448 return
441449 }
@@ -447,75 +455,86 @@ open class Module {
447455
448456 switch ( item, value) {
449457 case ( . value( . parameters( let p) ) , . value( let newArray) ) :
450- if verify. contains ( . all ) , p. shape != newArray. shape {
458+ if verify. contains ( . shapeMismatch ) , p. shape != newArray. shape {
451459 throw UpdateError . mismatchedSize (
452- key: key, expectedShape: p. shape, actualShape: newArray. shape)
460+ path: path, modules: modulePath, expectedShape: p. shape,
461+ actualShape: newArray. shape)
453462 }
454463 p. _updateInternal ( newArray)
455464
456465 case ( . value( . parameters( let p) ) , . none) :
457466 if Self . parameterIsValid ( key) {
458- throw UpdateError . keyNotFound ( base : describeType ( self ) , key : key )
467+ throw UpdateError . keyNotFound ( path : path , modules : modulePath )
459468 } else {
460469 // ignore it -- this isn't a parameter that requires update
461470 }
462471
463472 case ( . array( let array) , . array( let values) ) :
464473 for (i, ( arrayItem, valueItem) ) in zip ( array, values) . enumerated ( ) {
465- try apply ( key: " \( key) . \( i) " , arrayItem, valueItem)
474+ try apply ( key: " \( key) . \( i) " , path : path + [ " \( i ) " ] , arrayItem, valueItem)
466475 }
467476 if verify. contains ( . allModelKeysSet) {
468477 for i in values. count ..< array. count {
469- try apply ( key: " \( key) . \( i) " , array [ i] , . none)
478+ try apply ( key: " \( key) . \( i) " , path : path + [ " \( i ) " ] , array [ i] , . none)
470479 }
471480 }
472481
473482 case ( . array( let array) , . none) :
474483 for (i, arrayItem) in array. enumerated ( ) {
475- try apply ( key: " \( key) . \( i) " , arrayItem, . none)
484+ try apply ( key: " \( key) . \( i) " , path : path + [ " \( i ) " ] , arrayItem, . none)
476485 }
477486
478487 case ( . dictionary( let dictionary) , . dictionary( let values) ) :
479488 for (dictionaryKey, dictionaryItem) in dictionary {
489+ let newKey = " \( key) . \( dictionaryKey) "
490+ let path = path + [ dictionaryKey]
480491 if let valueItem = values [ key] {
481- try apply ( key: " \( key ) . \( dictionaryKey ) " , dictionaryItem, valueItem)
492+ try apply ( key: newKey , path : path , dictionaryItem, valueItem)
482493 } else if verify. contains ( . allModelKeysSet) {
483- try apply ( key: " \( key ) . \( dictionaryKey ) " , dictionaryItem, . none)
494+ try apply ( key: newKey , path : path , dictionaryItem, . none)
484495 }
485496 }
486497
487498 case ( . dictionary( let dictionary) , . none) :
488499 for (dictionaryKey, dictionaryItem) in dictionary {
489- try apply ( key: " \( key) . \( dictionaryKey) " , dictionaryItem, . none)
500+ let newKey = " \( key) . \( dictionaryKey) "
501+ let path = path + [ dictionaryKey]
502+ try apply ( key: newKey, path: path, dictionaryItem, . none)
490503 }
491504
492505 case ( . value( . module( let module) ) , . dictionary( let values) ) :
493- try module. update ( parameters: NestedDictionary ( values: values) , verify: verify)
506+ try module. update (
507+ parameters: NestedDictionary ( values: values) , verify: verify, path: path,
508+ modulePath: modulePath)
494509
495510 case ( . value( . module( let module) ) , . none) :
496- try module. update ( parameters: NestedDictionary ( ) , verify: verify)
511+ try module. update (
512+ parameters: NestedDictionary ( ) , verify: verify, path: path,
513+ modulePath: modulePath)
497514
498515 case ( . none, . none) , ( . value( . none) , . none) , ( . value( . other( _) ) , . none) :
499516 break
500517
501518 default :
502- fatalError ( " Unable to set \( key) on \( self ) : \( item) not compatible with \( value) " )
519+ fatalError (
520+ " Unable to set \( path. joined ( separator: " . " ) ) on \( modulePath. joined ( separator: " . " ) ) : \( item) not compatible with \( value. mapValues { $0. shape. description } ) "
521+ )
503522 }
504523 }
505524
506525 var processed = Set ( parameters. keys)
507526 for (key, item) in items ( ) {
508527 if let value = parameters [ key] {
509528 processed. remove ( key)
510- try apply ( key: key, item, value)
529+ try apply ( key: key, path : path + [ key ] , item, value)
511530 } else if verify. contains ( . allModelKeysSet) {
512- try apply ( key: key, item, . none)
531+ try apply ( key: key, path : path + [ key ] , item, . none)
513532 }
514533 }
515534
516535 if verify. contains ( . noUnusedKeys) && !processed. isEmpty {
517536 throw UpdateError . unhandledKeys (
518- base : describeType ( self ) , keys: processed. sorted ( ) )
537+ path : path , modules : modulePath , keys: processed. sorted ( ) )
519538 }
520539
521540 return self
@@ -594,9 +613,16 @@ open class Module {
594613 /// - ``leafModules()``
595614 /// - ``QuantizedLinear/quantize(model:groupSize:bits:predicate:)``
596615 @discardableResult
597- open func update( modules: ModuleChildren , verify: VerifyUpdate ) throws -> Self {
616+ open func update(
617+ modules: ModuleChildren , verify: VerifyUpdate , path: [ String ] = [ ] ,
618+ modulePath: [ String ] = [ ]
619+ ) throws -> Self {
598620
599- func apply( key: String , _ item: ModuleItem , _ value: NestedItem < String , Module > ) throws {
621+ let modulePath = modulePath + [ describeType ( self ) ]
622+
623+ func apply(
624+ key: String , path: [ String ] , _ item: ModuleItem , _ value: NestedItem < String , Module >
625+ ) throws {
600626 // item: single item from `items()`
601627 // value: single item with matching structure from `children()`
602628 //
@@ -605,7 +631,7 @@ open class Module {
605631 switch ( item, value) {
606632 case ( . value( . parameters) , . value) :
607633 fatalError (
608- " Unable to set \( key ) on \( self ) : parameters (MLXArray) cannot be updated with a Module "
634+ " Unable to set \( path . joined ( separator : " . " ) ) on \( modulePath . joined ( separator : " . " ) ) : parameters (MLXArray) cannot be updated with a Module "
609635 )
610636
611637 case ( . array( let items) , . array( let values) ) :
@@ -634,17 +660,17 @@ open class Module {
634660 default :
635661 // otherwise we don't know how to update it
636662 throw UpdateError . unableToCollectModulesFromContainer (
637- base : describeType ( self ) , key : key )
663+ path : path , modules : modulePath )
638664 }
639665 } else {
640666 // past the end of items
641667 throw UpdateError . unableToCollectModulesFromContainer (
642- base : describeType ( self ) , key : key )
668+ path : path , modules : modulePath )
643669 }
644670
645671 default :
646672 throw UpdateError . unableToCollectModulesFromContainer (
647- base : describeType ( self ) , key : key )
673+ path : path , modules : modulePath )
648674 }
649675 }
650676
@@ -678,7 +704,7 @@ open class Module {
678704 newModules [ item. key] = module
679705 default :
680706 throw UpdateError . unableToCollectModulesFromContainer (
681- base : describeType ( self ) , key : key )
707+ path : path , modules : modulePath )
682708 }
683709 }
684710
@@ -697,21 +723,23 @@ open class Module {
697723 try module. update ( modules: NestedDictionary ( values: values) , verify: verify)
698724
699725 default :
700- fatalError ( " Unable to set \( key) on \( self ) : \( item) not compatible with \( value) " )
726+ fatalError (
727+ " Unable to set \( path. joined ( separator: " . " ) ) on \( modulePath. joined ( separator: " . " ) ) : \( item) not compatible with \( value) "
728+ )
701729 }
702730 }
703731
704732 var processed = Set ( modules. keys)
705733 for (key, item) in items ( ) {
706734 if let value = modules [ key] {
707735 processed. remove ( key)
708- try apply ( key: key, item, value)
736+ try apply ( key: key, path : path + [ key ] , item, value)
709737 }
710738 }
711739
712740 if verify. contains ( . noUnusedKeys) && !processed. isEmpty {
713741 throw UpdateError . unhandledKeys (
714- base : describeType ( self ) , keys: processed. sorted ( ) )
742+ path : path , modules : modulePath , keys: processed. sorted ( ) )
715743 }
716744
717745 // rebuild the caches because the modules may have changed
@@ -806,7 +834,7 @@ open class Module {
806834 let localKeys = Set ( localKeys)
807835 for key in keys {
808836 if !localKeys. contains ( key) {
809- throw UpdateError . keyNotFound ( base : describeType ( self ) , key : key )
837+ throw UpdateError . keyNotFound ( path : [ key ] , modules : [ describeType ( self ) ] )
810838 }
811839 }
812840 }
@@ -1537,36 +1565,40 @@ private protocol TypeErasedSetterProvider {
15371565}
15381566
15391567enum UpdateError : Error {
1540- case unableToCollectModulesFromContainer( base : String , key : String )
1568+ case unableToCollectModulesFromContainer( path : [ String ] , modules : [ String ] )
15411569 case mismatchedContainers( base: String , key: String )
1542- case mismatchedSize( key : String , expectedShape: [ Int ] , actualShape: [ Int ] )
1543- case keyNotFound( base : String , key : String )
1570+ case mismatchedSize( path : [ String ] , modules : [ String ] , expectedShape: [ Int ] , actualShape: [ Int ] )
1571+ case keyNotFound( path : [ String ] , modules : [ String ] )
15441572 case needModuleInfo( String )
15451573 case unableToSet( String )
15461574 case unableToCast( String )
1547- case unhandledKeys( base : String , keys: [ String ] )
1575+ case unhandledKeys( path : [ String ] , modules : [ String ] , keys: [ String ] )
15481576}
15491577
15501578extension UpdateError : LocalizedError {
15511579 var errorDescription : String ? {
15521580 switch self {
1553- case . unableToCollectModulesFromContainer( let base, let key) :
1554- return " Unable to collect modules from container: \( base) \( key) "
1581+ case . unableToCollectModulesFromContainer( let path, let modules) :
1582+ return
1583+ " Unable to collect modules from container: \( path. joined ( separator: " . " ) ) in \( modules. joined ( separator: " . " ) ) "
15551584 case . mismatchedContainers( let base, let key) :
15561585 return " Mismatched containers: \( base) \( key) "
1557- case let . mismatchedSize( key: key, expectedShape: expectedShape, actualShape: actualShape) :
1586+ case let . mismatchedSize(
1587+ path, modules, expectedShape: expectedShape, actualShape: actualShape) :
15581588 return
1559- " Mismatched parameter \( key) shape. Actual \( actualShape) , expected \( expectedShape) "
1560- case . keyNotFound( let base, let key) :
1561- return " Key \( key) not found in \( base) "
1589+ " Mismatched parameter \( path. joined ( separator: " . " ) ) in \( modules. joined ( separator: " . " ) ) shape. Actual \( actualShape) , expected \( expectedShape) "
1590+ case . keyNotFound( let path, let modules) :
1591+ return
1592+ " Key \( path. joined ( separator: " . " ) ) not found in \( modules. joined ( separator: " . " ) ) "
15621593 case . needModuleInfo( let string) :
15631594 return string
15641595 case . unableToSet( let string) :
15651596 return string
15661597 case . unableToCast:
15671598 return " Unable to cast value "
1568- case . unhandledKeys( let base, let keys) :
1569- return " Unhandled keys \( keys) in \( base) "
1599+ case . unhandledKeys( let path, let modules, let keys) :
1600+ return
1601+ " Unhandled keys \( keys) in \( path. joined ( separator: " . " ) ) in \( modules. joined ( separator: " . " ) ) "
15701602 }
15711603 }
15721604}
0 commit comments