@@ -16,6 +16,7 @@ public protocol Quantized: Module {
1616 var bits : Int { get }
1717}
1818
19+ /// Quantize any ``Quantizable`` layer that is not already quantized.
1920public func quantizeSingle( layer: Module , groupSize: Int = 64 , bits: Int = 4 ) -> Quantized ? {
2021 if layer is Quantized {
2122 // already quantized
@@ -37,6 +38,8 @@ public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) ->
3738/// - bits: bits per parameter
3839/// - filter: filter receiving path and module -- return `false` to skip a layer
3940/// - apply: function to attempt the quantization -- the default implementation will quantize ``Linear`` and ``Embedding``
41+ /// ### See Also
42+ /// - ``quantize(model:filter:apply:)``
4043public func quantize(
4144 model: Module , groupSize: Int = 64 , bits: Int = 4 ,
4245 filter: ( String , Module ) -> Bool = { _, _ in true } ,
@@ -59,6 +62,38 @@ public func quantize(
5962 model. update ( modules: ModuleChildren . unflattened ( updates) )
6063}
6164
65+ /// Quantize the sub-modules of a module according to a filter.
66+ ///
67+ /// By default all ``Linear`` and ``Embedding`` layers will be quantized.
68+ ///
69+ /// - Parameters:
70+ /// - model: model to quantize
71+ /// - filter: filter receiving path and module -- return a tuple of `(groupSize: Int, bits: Int)` or `nil` to skip quantization
72+ /// - apply: function to attempt the quantization -- the default implementation will quantize ``Linear`` and ``Embedding`` layers
73+ /// ### See Also
74+ /// - ``quantize(model:groupSize:bits:filter:apply:)``
75+ public func quantize(
76+ model: Module ,
77+ filter: ( String , Module ) -> ( groupSize: Int , bits: Int ) ? ,
78+ apply: ( Module , Int , Int ) -> Module ? = quantizeSingle ( layer: groupSize: bits: )
79+ ) {
80+ let updates =
81+ model
82+ . leafModules ( )
83+ . flattened ( )
84+ . compactMap { ( path, m) -> ( String , Module ) ? in
85+ if let ( groupSize, bits) = filter ( path, m) {
86+ if let quantized = apply ( m, groupSize, bits) {
87+ return ( path, quantized)
88+ }
89+ }
90+
91+ return nil
92+ }
93+
94+ model. update ( modules: ModuleChildren . unflattened ( updates) )
95+ }
96+
6297/// The same as ``Embedding`` but with a quantized weight matrix.
6398open class QuantizedEmbedding : Embedding , Quantized {
6499
0 commit comments