1 | 1 | /** |
2 | | - * @Description : |
3 | | - * @Author : Azure-Tang |
| 2 | + * @Description : |
| 3 | + * @Author : Azure-Tang, Boxin Zhang |
4 | 4 | * @Date : 2024-07-25 13:38:30 |
5 | | - * @Version : 1.0.0 |
6 | | - * @LastEditors : kkk1nak0 |
7 | | - * @LastEditTime : 2024-08-12 03:05:04 |
8 | | - * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. |
| 5 | + * @Version : 0.2.2 |
| 6 | + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. |
9 | 7 | **/ |
10 | 8 |
11 | 9 | #include "custom_gguf/ops.h" |
21 | 19 | // namespace py = pybind11; |
22 | 20 |
23 | 21 | PYBIND11_MODULE(KTransformersOps, m) { |
24 | | - m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.", |
25 | | - py::arg("data"), py::arg("blk_size"), py::arg("device")); |
26 | | - m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.", |
27 | | - py::arg("data"), py::arg("blk_size"), py::arg("device")); |
28 | | - m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.", |
29 | | - py::arg("data"), py::arg("blk_size"), py::arg("device")); |
30 | | - m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.", |
31 | | - py::arg("data"), py::arg("blk_size"), py::arg("device")); |
32 | | - m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.", |
33 | | - py::arg("data"), py::arg("blk_size"), py::arg("device")); |
34 | | - m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.", |
35 | | - py::arg("data"), py::arg("blk_size"), py::arg("device")); |
36 | | - m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.", |
37 | | - py::arg("data"), py::arg("blk_size"), py::arg("device")); |
| 22 | + |
| 23 | + m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
| 24 | + return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 25 | + }, "Function to dequantize q8_0 data.", |
| 26 | + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
| 27 | + |
| 28 | + m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
| 29 | + return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 30 | + }, "Function to dequantize q6_k data.", |
| 31 | + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
| 32 | + |
| 33 | + m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
| 34 | + return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 35 | + }, "Function to dequantize q5_k data.", |
| 36 | + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
| 37 | + |
| 38 | + m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
| 39 | + return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 40 | + }, "Function to dequantize q4_k data.", |
| 41 | + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
| 42 | + |
| 43 | + m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
| 44 | + return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 45 | + }, "Function to dequantize q3_k data.", |
| 46 | + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
| 47 | + |
| 48 | + m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
| 49 | + return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 50 | + }, "Function to dequantize q2_k data.", |
| 51 | + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
| 52 | + |
| 53 | + m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
| 54 | + return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 55 | + }, "Function to dequantize iq4_xs data.", |
| 56 | + py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
| 57 | + |
38 | 58 | #ifdef KTRANSFORMERS_USE_CUDA |
39 | | - m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", |
40 | | - py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), |
41 | | - py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), |
42 | | - py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full")); |
| 59 | + m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", |
| 60 | + py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"), |
| 61 | + py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), |
| 62 | + py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full")); |
43 | 63 | #endif |
44 | 64 | } |
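
The reworked bindings no longer take a Python buffer: each dequantize function now receives the data pointer as a plain integer address (intptr_t) together with the byte count, the block size in bytes, the number of elements per block, the target torch device, and the target dtype. A minimal usage sketch from the Python side follows; it is not taken from this diff, and the file name, the q4_k block geometry (144 bytes / 256 elements per block), and the float32 target dtype are illustrative assumptions.

# Hypothetical usage of the updated bindings; names and sizes below are
# assumptions for illustration, not taken from this change.
import numpy as np
import torch
import KTransformersOps

# Raw q4_k tensor bytes, assumed to be already extracted from a GGUF file.
raw = np.fromfile("q4_k_tensor.bin", dtype=np.int8)

blk_size = 144       # bytes per q4_k block (assumed)
ele_per_blk = 256    # elements per q4_k block (assumed)

weights = KTransformersOps.dequantize_q4_k(
    raw.ctypes.data,          # data: pointer passed as an integer (intptr_t)
    raw.nbytes,               # num_bytes
    blk_size,
    ele_per_blk,
    torch.device("cuda:0"),   # device the dequantized tensor is produced on
    torch.float32,            # target_dtype
)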