Commit 72d09f3

Merge pull request #597 from kvcache-ai/feat-more-context
Feat more context
2 parents e908963 + f7f1059 · commit 72d09f3

29 files changed: +981, -284 lines

Lines changed: 44 additions & 24 deletions
@@ -1,11 +1,9 @@
 /**
- * @Description  :
- * @Author       : Azure-Tang
+ * @Description  :
+ * @Author       : Azure-Tang, Boxin Zhang
  * @Date         : 2024-07-25 13:38:30
- * @Version      : 1.0.0
- * @LastEditors  : kkk1nak0
- * @LastEditTime : 2024-08-12 03:05:04
- * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+ * @Version      : 0.2.2
+ * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/

 #include "custom_gguf/ops.h"
@@ -21,24 +19,46 @@
 // namespace py = pybind11;

 PYBIND11_MODULE(KTransformersOps, m) {
-    m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
-        py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
-        py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
-        py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
-        py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
-        py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
-        py::arg("data"), py::arg("blk_size"), py::arg("device"));
-    m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
-        py::arg("data"), py::arg("blk_size"), py::arg("device"));
+
+    m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
+        return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
+    }, "Function to dequantize q8_0 data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+
+    m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
+        return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
+    }, "Function to dequantize q6_k data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+
+    m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
+        return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
+    }, "Function to dequantize q5_k data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+
+    m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
+        return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
+    }, "Function to dequantize q4_k data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+
+    m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
+        return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
+    }, "Function to dequantize q3_k data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+
+    m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
+        return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
+    }, "Function to dequantize q2_k data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+
+    m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
+        return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
+    }, "Function to dequantize iq4_xs data.",
+    py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
+
 #ifdef KTRANSFORMERS_USE_CUDA
-    m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
-        py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
-        py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
-        py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
+    m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
+        py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
+        py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
+        py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
 #endif
 }
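
For context on the change above: the dequantize functions are no longer exposed directly; each is wrapped in a lambda that takes the buffer address from Python as a plain integer (intptr_t), casts it back to int8_t*, and forwards the new num_bytes, ele_per_blk, and target_dtype arguments to the kernel. The sketch below shows that lambda-wrapping pattern in isolation. It is illustrative only: dequantize_demo and demo_ops are hypothetical stand-ins for the project's real kernel and module names, and the only assumed dependency is the usual torch/extension.h toolchain (which the bindings above already use).

```cpp
#include <torch/extension.h>
#include <cstdint>

// Hypothetical kernel: decodes `num_bytes` of quantized data laid out as
// blocks of `blk_size` bytes, each holding `ele_per_blk` elements, into a
// dense tensor of `target_dtype` on `device`.
torch::Tensor dequantize_demo(const int8_t* data, int num_bytes, int blk_size,
                              int ele_per_blk, torch::Device device,
                              torch::Dtype target_dtype) {
    const int64_t num_blocks = num_bytes / blk_size;
    auto out = torch::zeros({num_blocks * ele_per_blk},
                            torch::TensorOptions().dtype(target_dtype).device(device));
    (void)data;  // a real implementation would decode each block of `data` into `out`
    return out;
}

PYBIND11_MODULE(demo_ops, m) {
    // Same pattern as the bindings above: Python hands over the buffer
    // address as an integer, and the lambda turns it back into a pointer
    // before calling the C++ kernel.
    m.def("dequantize_demo",
          [](intptr_t data, int num_bytes, int blk_size, int ele_per_blk,
             torch::Device device, torch::Dtype target_dtype) {
              return dequantize_demo(reinterpret_cast<const int8_t*>(data),
                                     num_bytes, blk_size, ele_per_blk,
                                     device, target_dtype);
          },
          "Dequantize a raw quantized buffer passed by address (sketch).",
          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"),
          py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
}
```

On the Python side a caller would typically pass a raw address such as tensor.data_ptr() for data, which avoids copying the quantized buffer through pybind11, while num_bytes and ele_per_blk let the kernel size its output and target_dtype/device choose where and in what precision the result is materialized per call.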

ktransformers/ktransformers_ext/cuda/custom_gguf/binding.cpp

Lines changed: 0 additions & 35 deletions
This file was deleted.
