Commit d2d0e11

Update quantization kernels
1 parent 778b61c

File tree

5 files changed (+54, -37 lines)

flake.lock

Lines changed: 4 additions & 3 deletions
Generated lockfile; diff not rendered.

flake.nix

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
       inputs.nixpkgs.follows = "hf-nix/nixpkgs";
     };
     nix-filter.url = "github:numtide/nix-filter";
-    hf-nix.url = "github:huggingface/hf-nix";
+    hf-nix.url = "github:huggingface/hf-nix/quantization-0.1.0";
     nixpkgs.follows = "hf-nix/nixpkgs";
     flake-utils.url = "github:numtide/flake-utils";
     rust-overlay = {
@@ -33,7 +33,7 @@
         };
         pkgs = import nixpkgs {
           inherit system;
-          inherit (hf-nix.lib) config;
+          config = hf-nix.lib.config system;
           overlays = [
             rust-overlay.overlays.default
             hf-nix.overlays.default

server/text_generation_server/layers/marlin/fp8.py

Lines changed: 16 additions & 9 deletions
@@ -76,15 +76,21 @@ def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert quantization is not None

         A_flat = A.view(-1, A.shape[-1])
-        C = quantization.fp8_marlin_gemm(
-            A_flat,
-            self.qweight,
-            self.scales,
-            self.workspace,
-            8,
-            A_flat.shape[0],
-            self.scales.shape[1],
-            A_flat.shape[1],
+        C = quantization.gptq_marlin_gemm(
+            a=A_flat,
+            c=None,
+            b_q_weight=self.qweight,
+            b_scales=self.scales,
+            global_scale=None,
+            b_zeros=None,
+            g_idx=None,
+            perm=None,
+            workspace=self.workspace,
+            b_q_type=quantization.scalar_type.scalar_types.float8_e4m3fn,
+            size_m=A_flat.shape[0],
+            size_n=self.scales.shape[1],
+            size_k=A_flat.shape[1],
+            use_fp32_reduce=True,
         )
         C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))

@@ -143,5 +149,6 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     )

     scales = permute_scales(scales)
+    scales = quantization.marlin_utils_fp8.fp8_fused_exponent_bias_into_scales(scales)

     return repacked, scales
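Note: aside from routing FP8 through the GPTQ Marlin kernel, the shape bookkeeping is unchanged: the input is flattened over its leading dimensions, size_m and size_k come from the flattened input, size_n from scales.shape[1], and the output is reshaped back afterwards. A minimal shape-only sketch in plain PyTorch (fake_gemm is a hypothetical stand-in for the Marlin kernel, not part of this diff):

import torch


def fake_gemm(a: torch.Tensor, size_m: int, size_n: int, size_k: int) -> torch.Tensor:
    # Hypothetical stand-in for quantization.gptq_marlin_gemm: only checks the
    # sizes implied by the keyword arguments and returns an output of the right shape.
    assert a.shape == (size_m, size_k)
    return a.new_zeros((size_m, size_n))


A = torch.randn(2, 3, 128)                # arbitrary leading dims, size_k = 128
n_out = 256                               # size_n, taken from scales.shape[1] in the real code

A_flat = A.view(-1, A.shape[-1])          # (size_m, size_k) = (6, 128)
C = fake_gemm(A_flat, size_m=A_flat.shape[0], size_n=n_out, size_k=A_flat.shape[1])
C = C.reshape(A.shape[:-1] + (n_out,))    # restore leading dims: (2, 3, 256)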

server/text_generation_server/layers/marlin/gptq.py

Lines changed: 18 additions & 19 deletions
@@ -256,7 +256,7 @@ class GPTQMarlinWeight(Weight):
     """

     qweight: torch.Tensor
-    qzeros: torch.Tensor
+    qzeros: Optional[torch.Tensor]
     scales: torch.Tensor
     g_idx: torch.Tensor
     perm: torch.Tensor
@@ -268,6 +268,7 @@ def __post_init__(self):
         assert self.scales.dtype in (torch.float16, torch.bfloat16)
         assert self.g_idx.dtype == torch.int32
         assert self.perm.dtype == torch.int32
+        assert self.qzeros is None or self.qzeros.numel() > 0

     def get_linear(self, bias: torch.Tensor):
         return GPTQMarlinLinear(
@@ -350,9 +351,6 @@ def repack_gptq_for_marlin(
         qweight, perm, in_features, out_features, bits
     )

-    if qzeros is None:
-        qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)
-
    scales = permute_scales(scales)

    is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
@@ -392,7 +390,7 @@ def __init__(
        if weight.bits not in (4, 8):
            raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization")

-        if weight.qzeros.numel() > 0:
+        if weight.qzeros is not None:
            if weight.bits == 4:
                self.quant_type = quantization.scalar_types.uint4
            else:
@@ -424,20 +422,21 @@ def forward(self, A: torch.Tensor) -> torch.Tensor:

        A_flat = A.view(-1, A.shape[-1])
        C = quantization.gptq_marlin_gemm(
-            A_flat,
-            self.qweight,
-            self.scales,
-            self.qzeros,
-            self.g_idx,
-            self.perm,
-            self.workspace,
-            self.quant_type,
-            A_flat.shape[0],
-            self.scales.shape[1],
-            A_flat.shape[1],
-            self.is_full_k,
-            self.qzeros.numel() > 0,
-            True,
+            a=A_flat,
+            c=None,
+            b_q_weight=self.qweight,
+            b_scales=self.scales,
+            global_scale=None,
+            b_zeros=self.qzeros,
+            g_idx=self.g_idx,
+            perm=self.perm,
+            workspace=self.workspace,
+            b_q_type=self.quant_type,
+            size_m=A_flat.shape[0],
+            size_n=self.scales.shape[1],
+            size_k=A_flat.shape[1],
+            is_k_full=self.is_full_k,
+            use_fp32_reduce=True,
        )
        C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
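Note on the qzeros change: GPTQMarlinWeight now carries qzeros as Optional[torch.Tensor] instead of an empty placeholder tensor, and call sites branch on `is not None` rather than `numel() > 0`. A trimmed-down, hypothetical sketch of the same pattern (ZeroPointWeight and has_zero_points are illustrative names, not from the codebase):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ZeroPointWeight:
    # Hypothetical, trimmed-down analogue of GPTQMarlinWeight.
    qweight: torch.Tensor
    qzeros: Optional[torch.Tensor]

    def __post_init__(self):
        # New invariant: either no zero points at all (symmetric quantization)
        # or a non-empty zero-point tensor; never an empty placeholder.
        assert self.qzeros is None or self.qzeros.numel() > 0

    @property
    def has_zero_points(self) -> bool:
        # Call sites now check `qzeros is not None` instead of `qzeros.numel() > 0`.
        return self.qzeros is not None


w_sym = ZeroPointWeight(qweight=torch.zeros(4, 4, dtype=torch.int32), qzeros=None)
assert not w_sym.has_zero_points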

server/text_generation_server/layers/moe/gptq_marlin.py

Lines changed: 14 additions & 4 deletions
@@ -202,9 +202,13 @@ def _pack_weight(
            device=weight.qweight.device,
        )
        qzeros = torch.empty(
-            (n_experts,) + weight.qzeros.shape,
-            dtype=weight.qzeros.dtype,
-            device=weight.qzeros.device,
+            (n_experts,) + ((0,) if weight.qzeros is None else weight.qzeros.shape),
+            dtype=(
+                weight.qweight.dtype if weight.qzeros is None else weight.qzeros.dtype
+            ),
+            device=(
+                weight.qweight.device if weight.qzeros is None else weight.qzeros.device
+            ),
        )
        scales = torch.empty(
            (n_experts,) + weight.scales.shape,
@@ -232,7 +236,13 @@ def _pack_weight(
        )

    moe_weight.qweight[expert] = weight.qweight
-    moe_weight.qzeros[expert] = weight.qzeros
+    moe_weight.qzeros[expert] = (
+        torch.zeros(
+            (0,), device=moe_weight.qzeros.device, dtype=moe_weight.qzeros.dtype
+        )
+        if weight.qzeros is None
+        else weight.qzeros
+    )
    moe_weight.scales[expert] = weight.scales
    moe_weight.g_idx[expert] = weight.g_idx
    moe_weight.perm[expert] = weight.perm
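The MoE packing keeps one shared (n_experts, ...) buffer per field; when qzeros is None the buffer gets a zero-length trailing shape and each expert slot is filled with an empty placeholder. A self-contained sketch of that pattern (pack_qzeros is a hypothetical helper; the int32 fallback dtype is an assumption, whereas the real code falls back to qweight's dtype and device):

from typing import List, Optional

import torch


def pack_qzeros(per_expert_qzeros: List[Optional[torch.Tensor]]) -> torch.Tensor:
    # Size the shared (n_experts, ...) buffer from the first expert, falling
    # back to a zero-size shape when there are no zero points (symmetric
    # quantization), then copy each expert's qzeros (or an empty placeholder)
    # into its slot. Assumes all experts agree on whether qzeros is present.
    first = per_expert_qzeros[0]
    shape = (0,) if first is None else tuple(first.shape)
    dtype = torch.int32 if first is None else first.dtype  # int32 fallback is an assumption
    packed = torch.empty((len(per_expert_qzeros),) + shape, dtype=dtype)
    for expert, qzeros in enumerate(per_expert_qzeros):
        packed[expert] = (
            torch.zeros((0,), dtype=packed.dtype) if qzeros is None else qzeros
        )
    return packed


# Symmetric case: no zero points at all, so each expert slot stays empty.
print(pack_qzeros([None, None]).shape)                               # torch.Size([2, 0])
# Asymmetric case: one packed int32 zero-point tensor per expert.
print(pack_qzeros([torch.ones(1, 8, dtype=torch.int32)] * 2).shape)  # torch.Size([2, 1, 8])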
