Commit 9c7349e

Qubitium authored
Eora cleanup (#1366)
* rename folder and cleanup namespace
* auto to float

Signed-off-by: Qubitium <Qubitium@modelcloud.ai>
1 parent 25f1607 commit 9c7349e

22 files changed (+8 −24 lines)

gptqmodel/nn_modules/qlinear/exllama_eora.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -54,7 +54,7 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
 
 
 class ExllamaEoraQuantLinear(BaseQuantLinear):
-    SUPPORTS_BITS = [4, 8]
+    SUPPORTS_BITS = [4]  # fused eora only validated for 4 bits
     SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
     SUPPORTS_DESC_ACT = [True, False]
    SUPPORTS_SYM = [True]  # TODO: validate False
```
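The only semantic change in this file drops 8-bit from the kernel's advertised capabilities. Below is a hedged sketch of how such `SUPPORTS_*` attributes are typically consulted when picking a quant-linear backend; the helper function and its parameters are illustrative, not GPTQModel's actual selection API.

```python
# Hypothetical capability gate. Only the class, its module path, and the
# SUPPORTS_* attribute names/values come from the diff above; the helper
# itself is an assumed sketch of how a loader might use them.
from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear

def can_use_exllama_eora(bits: int, group_size: int, desc_act: bool, sym: bool) -> bool:
    return (
        bits in ExllamaEoraQuantLinear.SUPPORTS_BITS            # now [4] only
        and group_size in ExllamaEoraQuantLinear.SUPPORTS_GROUP_SIZE
        and desc_act in ExllamaEoraQuantLinear.SUPPORTS_DESC_ACT
        and sym in ExllamaEoraQuantLinear.SUPPORTS_SYM          # asym still TODO
    )

# After this commit an 8-bit config no longer matches this backend:
# can_use_exllama_eora(8, 128, False, True) -> False
```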
File renamed without changes.

gptqmodel_ext/exllama2-vllm/eora/compat.cuh renamed to gptqmodel_ext/exllama_eora/eora/compat.cuh

Lines changed: 0 additions & 2 deletions
```diff
@@ -5,7 +5,6 @@ Copied from https://github.com/turboderp/exllamav2
 #ifndef _compat_cuh
 #define _compat_cuh
 
-namespace vllm {
 namespace gptq {
 // atomicAdd for half types, to support CC < 7.x
 
@@ -60,5 +59,4 @@ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
 #endif
 
 } // namespace gptq
-} // namespace vllm
 #endif
```

gptqmodel_ext/exllama2-vllm/eora/matrix_view.cuh renamed to gptqmodel_ext/exllama_eora/eora/matrix_view.cuh

Lines changed: 0 additions & 2 deletions
```diff
@@ -11,7 +11,6 @@ https://github.com/turboderp/exllama
 
 #include "qdq_util.cuh"
 
-namespace vllm {
 namespace gptq {
 
 class MatrixView_half {
@@ -291,5 +290,4 @@ class MatrixView_q8_row {
 };
 
 } // namespace gptq
-} // namespace vllm
 #endif
```

gptqmodel_ext/exllama2-vllm/eora/q_gemm.cu renamed to gptqmodel_ext/exllama_eora/eora/q_gemm.cu

Lines changed: 5 additions & 7 deletions
```diff
@@ -19,7 +19,6 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
 #include "qdq_4.cuh"
 #include "qdq_8.cuh"
 
-namespace vllm {
 namespace gptq {
 
 #define BLOCK_KN_SIZE 128
@@ -336,8 +335,8 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel_eora(
     for (int j = 0; j < 4; ++j) {
 #pragma unroll
       for (int m = 0; m < m_count; m++) {
-        auto a1 = __half2float(*(Ax_.item_ptr(offset_m + m, r)));
-        auto a2 = __half2float(*(eora_b_.item_ptr(r, n + j)));
+        float a1 = __half2float(*(Ax_.item_ptr(offset_m + m, r)));
+        float a2 = __half2float(*(eora_b_.item_ptr(r, n + j)));
         float product = a1 * a2;
         block_c[m][j] = block_c[m][j] + product;
       }
```
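The `auto` → `float` change is behavior-preserving: `__half2float` already returns `float`, so the explicit type simply documents that the EoRA correction term is accumulated in fp32. As a reference point, here is a hedged PyTorch sketch of what the fused kernel computes; the function name and shapes are assumed for illustration, not part of the project's API.

```python
import torch

def eora_gemm_reference(x, w_dq, eora_a, eora_b):
    """Assumed shapes: x (m, k), w_dq (k, n) dequantized GPTQ weight,
    eora_a (k, r), eora_b (r, n) low-rank EoRA factors, all fp16."""
    base = x @ w_dq      # the plain gemm_half_q_half path
    ax = x @ eora_a      # the kernel's "Ax" intermediate
    # mirrors the inner loop above: a1 * a2 products accumulated as float
    correction = (ax.float() @ eora_b.float()).to(x.dtype)
    return base + correction
```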
```diff
@@ -2074,7 +2073,6 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
 }
 
 } // namespace gptq
-} // namespace vllm
 
 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
@@ -2086,7 +2084,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
   at::Tensor temp_dq = torch::empty(
       {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
 
-  vllm::gptq::gemm_half_q_half_cuda(
+  gptq::gemm_half_q_half_cuda(
       at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(),
       (const uint32_t*)b_q_weight.data_ptr(),
       (const uint32_t*)b_gptq_qzeros.data_ptr(),
@@ -2112,7 +2110,7 @@ torch::Tensor gptq_gemm_lora(torch::Tensor a, torch::Tensor b_q_weight,
   at::Tensor temp_dq = torch::empty(
       {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
 
-  vllm::gptq::gemm_half_q_half_cuda_eora(
+  gptq::gemm_half_q_half_cuda_eora(
       at::cuda::getCurrentCUDABlasHandle(), (const half*)a.data_ptr(),
       (const uint32_t*)b_q_weight.data_ptr(),
       (const uint32_t*)b_gptq_qzeros.data_ptr(),
@@ -2133,7 +2131,7 @@ torch::Tensor gptq_gemm_lora(torch::Tensor a, torch::Tensor b_q_weight,
 
 void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
-  vllm::gptq::shuffle_exllama_weight(
+  gptq::shuffle_exllama_weight(
       (uint32_t*)q_weight.data_ptr(),
       q_perm.device().is_meta() || q_perm.numel() == 0
           ? NULL
```
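With the outer `vllm` namespace gone, the host wrappers now call into `gptq::` directly. The ternary on `q_perm` above also shows the NULL path for an absent permutation; the following Python-side sketch triggers it. The binding module name `eora_ext` is assumed for illustration, but the `gptq_shuffle(q_weight, q_perm, bit)` signature and the empty-tensor convention come from this diff.

```python
import torch

# Requires a CUDA device at runtime; shapes are illustrative.
in_features, out_features = 4096, 4096
q_weight = torch.zeros(in_features * 4 // 32, out_features,
                       dtype=torch.int32, device="cuda")   # 4-bit packed rows
q_perm = torch.empty(0, dtype=torch.int32, device="cuda")  # numel() == 0
# eora_ext.gptq_shuffle(q_weight, q_perm, 4)  # empty q_perm -> NULL in C++
```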
