Skip to content

Commit ca9d634

Browse files
fix qqq quant error (#1498)
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent a83003a commit ca9d634

File tree

2 files changed: 5 additions and 4 deletions

2 files changed: 5 additions and 4 deletions

gptqmodel/looper/qqq_processor.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
from ..looper.named_module import NamedModule
2626
from ..models import BaseGPTQModel
2727
from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE,
28-
PROCESS_LOG_NAME, PROCESS_LOG_TIME, QUANT_LOG_DAMP, QUANT_LOG_LOSS)
28+
PROCESS_LOG_NAME, PROCESS_LOG_TIME, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
2929
from ..quantization.config import QUANT_METHOD, QuantizeConfig
3030
from ..quantization.gptq import CPU
3131
from ..quantization.qqq import QQQ
@@ -121,7 +121,7 @@ def process(self, module: NamedModule):
121121
# logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}")
122122
## Need to return the quantized_weight for offloading
123123
g = gptq[module.name]
124-
wq, scale, zero, g_idx, duration, avg_loss, damp_percent, scale_extra = g.quantize()
124+
wq, scale, zero, g_idx, duration, avg_loss, damp_percent, scale_extra, nsamples = g.quantize()
125125
## Assign the quantized weight to the weight
126126
#gptq[name].layer.weight.data = q_full_weight.to(device=gptq[name].device)
127127

@@ -151,6 +151,7 @@ def process(self, module: NamedModule):
151151
PROCESS_LOG_LAYER: module.layer_index,
152152
PROCESS_LOG_MODULE: module.name,
153153
QUANT_LOG_LOSS: f"{avg_loss:.5f}",
154+
QUANT_LOG_NSAMPLES: f"{nsamples}",
154155
QUANT_LOG_DAMP: f"{damp_percent:.5f}",
155156
PROCESS_LOG_TIME: f"{duration:.3f}",
156157
PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}",

gptqmodel/quantization/qqq.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@ def quantize(
1515
self,
1616
blocksize=128,
1717
):
18-
wq, scale, zero, g_idx, duration, avg_loss, damp_percent = super().quantize(blocksize=blocksize)
18+
wq, scale, zero, g_idx, duration, avg_loss, damp_percent, nsamples = super().quantize(blocksize=blocksize)
1919

2020
# post int8 quant
2121
scale_extra = None
@@ -32,4 +32,4 @@ def quantize(
3232
)
3333
quantizer_extra.find_params(self.module.weight.data.clone(), weight=True)
3434
scale_extra = quantizer_extra.scale
35-
return wq, scale, zero, g_idx, duration, avg_loss, damp_percent, scale_extra
35+
return wq, scale, zero, g_idx, duration, avg_loss, damp_percent, scale_extra, nsamples

0 commit comments

Comments
 (0)