|
1 | 1 | /** |
2 | | - * @Description : |
| 2 | + * @Description : |
3 | 3 | * @Author : Azure-Tang, Boxin Zhang |
4 | 4 | * @Date : 2024-07-25 13:38:30 |
5 | 5 | * @Version : 0.2.2 |
6 | | - * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. |
| 6 | + * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. |
7 | 7 | **/ |
8 | 8 |
|
9 | 9 | #include "custom_gguf/ops.h" |
|
20 | 20 |
|
21 | 21 | PYBIND11_MODULE(KTransformersOps, m) { |
22 | 22 |
|
// Python binding "dequantize_q8_0": a lambda that forwards to the C++ dequantize_q8_0 routine.
// Rows marked '-' / '+' appear to be the removed / added sides of a diff of this binding.
//   data         - buffer address passed as an integer (intptr_t); cast to int8_t* before the call.
//   num_bytes, blk_size, ele_per_blk, device - forwarded unchanged.
//   target_dtype - '-' side: received directly as torch::Dtype.
//                  '+' side: received as a generic py::object and converted to torch::Dtype
//                  via torch::python::detail::py_object_to_dtype before the call.
// NOTE(review): py_object_to_dtype sits in a `detail` namespace - presumably an internal
// torch helper; confirm it is available in every torch version this module builds against.
// NOTE(review): no validation of `data` / `num_bytes` is visible here; presumably the caller
// must supply a valid address and size - confirm.
23 | | - m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
24 | | - return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 23 | + m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { |
| 24 | + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); |
| 25 | + return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); |
25 | 26 | }, "Function to dequantize q8_0 data.", |
26 | 27 | py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
27 | 28 |
|
// Python binding "dequantize_q6_k": a lambda that forwards to the C++ dequantize_q6_k routine.
// Rows marked '-' / '+' appear to be the removed / added sides of a diff of this binding.
//   data         - buffer address passed as an integer (intptr_t); cast to int8_t* before the call.
//   num_bytes, blk_size, ele_per_blk, device - forwarded unchanged.
//   target_dtype - '-' side: received directly as torch::Dtype.
//                  '+' side: received as a generic py::object and converted to torch::Dtype
//                  via torch::python::detail::py_object_to_dtype before the call.
// NOTE(review): py_object_to_dtype sits in a `detail` namespace - presumably an internal
// torch helper; confirm it is available in every torch version this module builds against.
28 | | - m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
29 | | - return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 29 | + m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { |
| 30 | + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); |
| 31 | + return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); |
30 | 32 | }, "Function to dequantize q6_k data.", |
31 | 33 | py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
32 | 34 |
|
// Python binding "dequantize_q5_k": a lambda that forwards to the C++ dequantize_q5_k routine.
// Rows marked '-' / '+' appear to be the removed / added sides of a diff of this binding.
//   data         - buffer address passed as an integer (intptr_t); cast to int8_t* before the call.
//   num_bytes, blk_size, ele_per_blk, device - forwarded unchanged.
//   target_dtype - '-' side: received directly as torch::Dtype.
//                  '+' side: received as a generic py::object and converted to torch::Dtype
//                  via torch::python::detail::py_object_to_dtype before the call.
// NOTE(review): py_object_to_dtype sits in a `detail` namespace - presumably an internal
// torch helper; confirm it is available in every torch version this module builds against.
33 | | - m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
34 | | - return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 35 | + m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { |
| 36 | + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); |
| 37 | + return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); |
35 | 38 | }, "Function to dequantize q5_k data.", |
36 | 39 | py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
37 | 40 |
|
// Python binding "dequantize_q4_k": a lambda that forwards to the C++ dequantize_q4_k routine.
// Rows marked '-' / '+' appear to be the removed / added sides of a diff of this binding.
//   data         - buffer address passed as an integer (intptr_t); cast to int8_t* before the call.
//   num_bytes, blk_size, ele_per_blk, device - forwarded unchanged.
//   target_dtype - '-' side: received directly as torch::Dtype.
//                  '+' side: received as a generic py::object and converted to torch::Dtype
//                  via torch::python::detail::py_object_to_dtype before the call.
// NOTE(review): py_object_to_dtype sits in a `detail` namespace - presumably an internal
// torch helper; confirm it is available in every torch version this module builds against.
38 | | - m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
39 | | - return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 41 | + m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { |
| 42 | + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); |
| 43 | + return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); |
40 | 44 | }, "Function to dequantize q4_k data.", |
41 | 45 | py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
42 | 46 |
|
// Python binding "dequantize_q3_k": a lambda that forwards to the C++ dequantize_q3_k routine.
// Rows marked '-' / '+' appear to be the removed / added sides of a diff of this binding.
//   data         - buffer address passed as an integer (intptr_t); cast to int8_t* before the call.
//   num_bytes, blk_size, ele_per_blk, device - forwarded unchanged.
//   target_dtype - '-' side: received directly as torch::Dtype.
//                  '+' side: received as a generic py::object and converted to torch::Dtype
//                  via torch::python::detail::py_object_to_dtype before the call.
// NOTE(review): py_object_to_dtype sits in a `detail` namespace - presumably an internal
// torch helper; confirm it is available in every torch version this module builds against.
43 | | - m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
44 | | - return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 47 | + m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { |
| 48 | + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); |
| 49 | + return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); |
45 | 50 | }, "Function to dequantize q3_k data.", |
46 | 51 | py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
47 | 52 |
|
// Python binding "dequantize_q2_k": a lambda that forwards to the C++ dequantize_q2_k routine.
// Rows marked '-' / '+' appear to be the removed / added sides of a diff of this binding.
//   data         - buffer address passed as an integer (intptr_t); cast to int8_t* before the call.
//   num_bytes, blk_size, ele_per_blk, device - forwarded unchanged.
//   target_dtype - '-' side: received directly as torch::Dtype.
//                  '+' side: received as a generic py::object and converted to torch::Dtype
//                  via torch::python::detail::py_object_to_dtype before the call.
// NOTE(review): py_object_to_dtype sits in a `detail` namespace - presumably an internal
// torch helper; confirm it is available in every torch version this module builds against.
48 | | - m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
49 | | - return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 53 | + m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { |
| 54 | + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); |
| 55 | + return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); |
50 | 56 | }, "Function to dequantize q2_k data.", |
51 | 57 | py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
52 | 58 |
|
// Python binding "dequantize_iq4_xs": a lambda that forwards to the C++ dequantize_iq4_xs routine.
// Rows marked '-' / '+' appear to be the removed / added sides of a diff of this binding.
//   data         - buffer address passed as an integer (intptr_t); cast to int8_t* before the call.
//   num_bytes, blk_size, ele_per_blk, device - forwarded unchanged.
//   target_dtype - '-' side: received directly as torch::Dtype.
//                  '+' side: received as a generic py::object and converted to torch::Dtype
//                  via torch::python::detail::py_object_to_dtype before the call.
// NOTE(review): py_object_to_dtype sits in a `detail` namespace - presumably an internal
// torch helper; confirm it is available in every torch version this module builds against.
53 | | - m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) { |
54 | | - return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype); |
| 59 | + m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, py::object target_dtype) { |
| 60 | + torch::Dtype dtype = torch::python::detail::py_object_to_dtype(target_dtype); |
| 61 | + return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, dtype); |
55 | 62 | }, "Function to dequantize iq4_xs data.", |
56 | 63 | py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype")); |
57 | 64 |
|
|
0 commit comments