Skip to content

Commit 037c5c0

Browse files
authored
Add Group Aware Reordering (GAR) support and config option (#1656)
1 parent 5d2911a commit 037c5c0

File tree

4 files changed

+127
-0
lines changed

4 files changed

+127
-0
lines changed

README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,23 @@ dynamic = {
396396
* Pass `auto_gc = False` to `quantize()` api to speed up quantization if gpu has plenty of vram and does not need to call slow gc.
397397
* Pass `buffered_fwd = True` to `quantize()` api to potentially speed up quantization if gpu has plenty of vram and can hold all fwd inputs in vram.
398398

399+
#### Group Aware Reordering (GAR)
400+
401+
Group Aware Reordering (GAR) is an enhanced activation reordering scheme designed to significantly improve the accuracy of quantized models without incurring additional inference overhead. Unlike traditional activation reordering, GAR permits only two kinds of permutations: re-orderings of the weights *within* each individual group, and re-orderings of entire groups as units. This ensures each group's associated scales and zero-points remain efficiently accessible during inference, thereby avoiding any inference-time overhead.
402+
403+
How to enable GAR:
404+
405+
Set the `hyb_act` parameter to `True` and disable the default activation reordering by setting `desc_act` to `False` in your `QuantizeConfig`. For example:
406+
407+
```python
408+
quant_config = QuantizeConfig(bits=4, group_size=128, desc_act=False, hyb_act=True)
409+
```
410+
411+
This feature is based on the method introduced in:
412+
413+
[T Gafni, A Karnieli, Y Hanani, "Dual Precision Quantization for Efficient and Accurate Deep Neural Networks Inference," CVPR Workshop, 2025.](https://openaccess.thecvf.com/content/CVPR2025W/eLVM/html/Gafni_Dual_Precision_Quantization_for_Efficient_and_Accurate_Deep_Neural_Networks_CVPRW_2025_paper.html)
414+
415+
399416

400417
### Attribution of Quantization Methods:
401418

gptqmodel/quantization/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ class QuantizeConfig():
170170
damp_auto_increment: float = field(default=0.01)
171171

172172
desc_act: bool = field(default=True)
173+
hyb_act: bool = field(default=False)
173174
static_groups: bool = field(default=False)
174175
sym: bool = field(default=True)
175176
true_sequential: bool = field(default=True)
@@ -461,6 +462,7 @@ def to_dict(self):
461462
"dynamic": self.dynamic,
462463
"group_size": self.group_size,
463464
"desc_act": self.desc_act,
465+
"hyb_act": self.hyb_act,
464466
"sym": self.sym,
465467
"lm_head": self.lm_head,
466468
QUANT_METHOD_FIELD:self.quant_method,

gptqmodel/quantization/gar.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
3+
4+
def compute_local_perms(diag_H, groupsize):
    """
    For each group, compute a permutation that orders the indices in descending order
    based on the corresponding diagonal values of H.

    Args:
        diag_H (Tensor): 1D tensor representing the diagonal of the Hessian.
        groupsize (int): Number of columns/weights per group.

    Returns:
        local_perms (list of Tensors): Each element is a permutation (indices) for that group.

    Note:
        Any trailing columns beyond ``(numel // groupsize) * groupsize`` are ignored,
        matching the grouping behavior used elsewhere in the quantizer.
    """
    n = diag_H.numel()
    num_groups = n // groupsize
    if num_groups == 0:
        return []
    # Vectorized: view the (truncated) diagonal as (num_groups, groupsize) and
    # argsort each row at once instead of looping group-by-group in Python.
    grouped = diag_H[: num_groups * groupsize].reshape(num_groups, groupsize)
    # Descending order within each group.
    perms = torch.argsort(grouped, dim=1, descending=True)
    # Return a list of per-group permutations to preserve the original interface.
    return [perms[g] for g in range(num_groups)]
27+
28+
def compute_global_perm(diag_H, groupsize):
    """
    Compute a permutation for the groups themselves. Here we choose the maximum diagonal value
    within each group as the group metric and sort the groups in descending order.

    Args:
        diag_H (Tensor): 1D tensor representing the diagonal of the Hessian.
        groupsize (int): Number of columns/weights per group.

    Returns:
        global_perm (Tensor): 1D tensor of length num_groups with the new order of groups,
        on the same device as ``diag_H``.
    """
    n = diag_H.numel()
    num_groups = n // groupsize
    if num_groups == 0:
        return torch.empty(0, dtype=torch.long, device=diag_H.device)
    # Vectorized: one reshape + row-wise max instead of a Python loop with a
    # per-group .item() call (each .item() forces a device->host sync on GPU).
    grouped = diag_H[: num_groups * groupsize].reshape(num_groups, groupsize)
    group_metric = grouped.max(dim=1).values
    # Sort groups by their metric, largest first; stays on diag_H's device.
    global_perm = torch.argsort(group_metric, descending=True)
    return global_perm
51+
52+
def compose_final_perm(local_perms, global_perm, groupsize):
    """
    Compose the final overall permutation from the local and global permutations.

    Args:
        local_perms (list of Tensors): Local permutation for each group.
        global_perm (Tensor): Global group permutation.
        groupsize (int): Number of indices per group.

    Returns:
        final_perm (Tensor): 1D long tensor that maps original indices to new positions.
    """
    # Offset each group's local permutation into the full index space, visiting
    # groups in the order given by global_perm. Using tensor arithmetic + cat
    # avoids the per-element .item() of a Python loop (each .item() is a
    # device->host sync when the perms live on GPU).
    pieces = []
    for orig_group in global_perm.tolist():
        offset = orig_group * groupsize
        pieces.append(local_perms[orig_group] + offset)
    if not pieces:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(pieces).to(torch.long)
76+
77+
def invert_perm(perm):
    """
    Compute the inverse of a permutation vector.

    Args:
        perm (Tensor): A 1D tensor containing a permutation of indices.

    Returns:
        inv (Tensor): The inverse permutation such that inv[perm] == torch.arange(len(perm)).
    """
    # Scatter position indices into the slots named by perm: slot perm[i]
    # receives value i, which is exactly the inverse mapping.
    positions = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype)
    inverse = torch.empty_like(perm)
    inverse.scatter_(0, perm, positions)
    return inverse

gptqmodel/quantization/gptq.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,14 @@ def quantize(
333333
H = H[perm][:, perm]
334334
invperm = torch.argsort(perm)
335335

336+
if hasattr(self.qcfg, "hyb_act") and self.qcfg.hyb_act and not self.qcfg.desc_act:
337+
from .gar import compute_local_perms, compute_global_perm, compose_final_perm
338+
local_perms = compute_local_perms(torch.diag(H), self.qcfg.group_size)
339+
global_perm = compute_global_perm(torch.diag(H), self.qcfg.group_size)
340+
final_perm = compose_final_perm(local_perms, global_perm, self.qcfg.group_size)
341+
W = W[:, final_perm]
342+
H = H[final_perm][:, final_perm]
343+
336344
Losses = torch.zeros_like(W)
337345
Q = torch.zeros_like(W)
338346

@@ -416,6 +424,17 @@ def quantize(
416424
Q = Q[:, invperm]
417425
g_idx = g_idx[invperm]
418426

427+
if hasattr(self.qcfg, "hyb_act") and self.qcfg.hyb_act and not self.qcfg.desc_act:
428+
from .gar import invert_perm
429+
inv_final = invert_perm(final_perm)
430+
Q = Q[:, inv_final]
431+
inv_global_perm = invert_perm(global_perm)
432+
inv_global_perm_list = inv_global_perm.tolist()
433+
temp_scale = [ scale[i] for i in inv_global_perm_list ]
434+
scale = temp_scale
435+
temp_zero = [ zero[i] for i in inv_global_perm_list ]
436+
zero = temp_zero
437+
419438
if isinstance(self.module, transformers.Conv1D):
420439
Q = Q.t()
421440

0 commit comments

Comments
 (0)