
Commit 5976316

davidkoski and awni authored
Add quantize function that can return different parameters per layer (#229)
* Add quantize function that can return different parameters per layer

Co-authored-by: Awni Hannun <awni@apple.com>
1 parent bef30f0 commit 5976316

File tree: 1 file changed (+35, -0)


Source/MLXNN/Quantized.swift

Lines changed: 35 additions & 0 deletions
@@ -16,6 +16,7 @@ public protocol Quantized: Module {
     var bits: Int { get }
 }
 
+/// Quantize any ``Quantizable`` layer that is not already quantized.
 public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) -> Quantized? {
     if layer is Quantized {
         // already quantized
@@ -37,6 +38,8 @@ public func quantizeSingle(layer: Module, groupSize: Int = 64, bits: Int = 4) ->
 /// - bits: bits per parameter
 /// - filter: filter receiving path and module -- return `false` to skip a layer
 /// - apply: function to attempt the quantization -- the default implementation will quantize ``Linear`` and ``Embedding``
+/// ### See Also
+/// - ``quantize(model:filter:apply:)``
 public func quantize(
     model: Module, groupSize: Int = 64, bits: Int = 4,
     filter: (String, Module) -> Bool = { _, _ in true },
@@ -59,6 +62,38 @@ public func quantize(
     model.update(modules: ModuleChildren.unflattened(updates))
 }
 
+/// Quantize the sub-modules of a module according to a filter.
+///
+/// By default all ``Linear`` and ``Embedding`` layers will be quantized.
+///
+/// - Parameters:
+/// - model: model to quantize
+/// - filter: filter receiving path and module -- return a tuple of `(groupSize: Int, bits: Int)` or `nil` to skip quantization
+/// - apply: function to attempt the quantization -- the default implementation will quantize ``Linear`` and ``Embedding`` layers
+/// ### See Also
+/// - ``quantize(model:groupSize:bits:filter:apply:)``
+public func quantize(
+    model: Module,
+    filter: (String, Module) -> (groupSize: Int, bits: Int)?,
+    apply: (Module, Int, Int) -> Module? = quantizeSingle(layer:groupSize:bits:)
+) {
+    let updates =
+        model
+        .leafModules()
+        .flattened()
+        .compactMap { (path, m) -> (String, Module)? in
+            if let (groupSize, bits) = filter(path, m) {
+                if let quantized = apply(m, groupSize, bits) {
+                    return (path, quantized)
+                }
+            }
+
+            return nil
+        }
+
+    model.update(modules: ModuleChildren.unflattened(updates))
+}
+
 /// The same as ``Embedding`` but with a quantized weight matrix.
 open class QuantizedEmbedding: Embedding, Quantized {
 
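The new overload is easiest to see with a small usage sketch. The snippet below is illustrative only: `MyTransformer`, the `"lm_head"` path check, and the chosen group sizes and bit widths are hypothetical, while `quantize`, `quantizeSingle`, `Module`, `Linear`, and `Embedding` come from MLXNN as shown in the diff above.

import MLXNN

// Hypothetical model type; any Module whose leaf modules include Linear and
// Embedding layers works the same way.
let model: Module = MyTransformer()

// Per-layer quantization: skip a (hypothetical) output head by path, keep the
// embedding at 8 bits, quantize remaining Linear layers at 4 bits, and leave
// everything else untouched by returning nil. The default `apply` is
// quantizeSingle(layer:groupSize:bits:), which only quantizes Linear and
// Embedding layers.
quantize(
    model: model,
    filter: { path, module in
        if path.hasSuffix("lm_head") {
            return nil
        }
        if module is Embedding {
            return (groupSize: 64, bits: 8)
        }
        if module is Linear {
            return (groupSize: 64, bits: 4)
        }
        return nil
    }
)

// For comparison, the existing overload applies a single setting everywhere:
// quantize(model: model, groupSize: 64, bits: 4)

Returning an optional `(groupSize:, bits:)` tuple folds the skip decision and the per-layer parameters into a single callback, which is why this overload has no separate `groupSize`/`bits` arguments.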

0 commit comments
