
Commit 1ef13b1

enable 8-bits for exllama eora kernel (unfused) (#1367)
Signed-off-by: Qubitium <Qubitium@modelcloud.ai>
1 parent 9c7349e · commit 1ef13b1

File tree: 1 file changed (+22 −6 lines)

gptqmodel/nn_modules/qlinear/exllama_eora.py

Lines changed: 22 additions & 6 deletions
@@ -32,7 +32,7 @@
 except ImportError as e:
     exllama_v2v_import_exception = e
 
-logger = setup_logger()
+log = setup_logger()
 
 
 # TODO remove this?
@@ -54,7 +54,7 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
 
 
 class ExllamaEoraQuantLinear(BaseQuantLinear):
-    SUPPORTS_BITS = [4] # fused eora only validated for 4 bits
+    SUPPORTS_BITS = [4, 8] # fused eora only validated for 4 bits
     SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
     SUPPORTS_DESC_ACT = [True, False]
     SUPPORTS_SYM = [True] # TODO: validate False
@@ -156,32 +156,48 @@ def post_init(self):
     def forward(self, x):
         x_dtype = x.dtype
         if x_dtype != torch.float16:
-            logger.warning_once(
+            log.warn.once(
                 f"Exllama EoRA kernel requires a float16 input activation, while {x.dtype} was passed. Casting to float16.\nMake sure you loaded your model with torch_dtype=torch.float16, that the model definition does not inadvertently cast to float32, or disable AMP Autocast that may produce float32 intermediate activations in the model."
             )
 
             x = x.to(dtype=torch.float16)
 
         # sync with vllm
-        out_shape = x.shape[:-1] + (self.qweight.shape[-1],)
+        # log.info(f"x shape: {x.shape}")
+        # log.info(f"qweight shape: {self.qweight.shape}")
+        # log.info(f"in_features: {self.in_features}")
+        # log.info(f"out_features: {self.out_features}")
+        # log.info(f"x.shape[:-1]: {x.shape[:-1]}")
+        # log.info(f"self.qweight.shape[-1],: {self.qweight.shape[-1],}")
+
+        out_shape = x.shape[:-1] + (self.out_features,)
         reshaped_x = x.reshape(-1, x.shape[-1])
 
+        # log.info(f"out_shape: {out_shape}")
+        # log.info(f"reshaped_x: {reshaped_x.shape}")
+
         # TODO: need to run checks to make sure there is no performance regression padding with F.pad
         # if in_features is padded, we need to pad the input as well
         # if x.size(-1) != self.in_features:
         # x = F.pad(x, self.in_features_padding_shape)
 
         if self.adapter:
-            output = gptq_gemm_lora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B) # fused
-            # output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits).add_((reshaped_x @ self.adapter.lora_A) @ self.adapter.lora_B) # normal
+            # only 4 bits fused eora kernel has been validated
+            if self.bits == 4:
+                output = gptq_gemm_lora(x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits, x @ self.adapter.lora_A, self.adapter.lora_B) # fused
+            else:
+                output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits).add_((reshaped_x @ self.adapter.lora_A) @ self.adapter.lora_B) # normal
         else:
            output = gptq_gemm(reshaped_x, self.qweight, self.qzeros, self.scales, self.g_idx, self.bits)
 
 
         if self.bias is not None:
             output.add_(self.bias)
 
+        # log.info(f"output: {output.shape}")
+
         # sync with vllm
         output = output.reshape(out_shape)
+        # log.info(f"output reshaped: {output.shape}")
 
         return output.to(dtype=x_dtype)
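
In short, the commit keeps the fused gptq_gemm_lora kernel for the validated 4-bit case and routes 8-bit EoRA through the unfused path: a plain GPTQ GEMM followed by a separate low-rank correction added in place. Below is a minimal sketch of the arithmetic behind that unfused branch, not the module's real API: a dense weight W stands in for the packed qweight/qzeros/scales that gptq_gemm dequantizes on GPU, and the function and tensor names are illustrative only.

import torch

# Sketch of the unfused EoRA path taken when bits != 4:
#   y = x @ W + (x @ lora_A) @ lora_B, reshaped back to x's leading dims.
# W stands in for the dequantized GPTQ weight; the real kernel additionally
# requires float16 inputs and runs on GPU.

def unfused_eora_sketch(x, W, lora_A, lora_B, bias=None):
    out_shape = x.shape[:-1] + (W.shape[-1],)   # mirrors (self.out_features,)
    x2d = x.reshape(-1, x.shape[-1])            # mirrors reshaped_x
    out = x2d @ W                               # stand-in for gptq_gemm(...)
    out.add_((x2d @ lora_A) @ lora_B)           # low-rank EoRA/LoRA correction
    if bias is not None:
        out.add_(bias)
    return out.reshape(out_shape)               # restore leading batch dims

# illustrative shapes: in_features=512, out_features=1024, rank=64
x = torch.randn(2, 8, 512)
W = torch.randn(512, 1024)
A = torch.randn(512, 64)
B = torch.randn(64, 1024)
print(unfused_eora_sketch(x, W, A, B).shape)    # torch.Size([2, 8, 1024])

The fused gptq_gemm_lora kernel folds the (x @ lora_A) @ lora_B product into the same launch; per the diff's comment, only the 4-bit case of that fused kernel has been validated, so 8-bit falls back to the two-step computation above.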
