Skip to content

Commit cd1fce6

Browse files
authored
SYCL: Add set_rows support for quantized types (#14883)
* SYCL: Add set_rows support for quantized types

  This commit adds support for the GGML_OP_SET_ROWS operation for various quantized tensor types (Q8_0, Q5_1, Q5_0, Q4_1, Q4_0, IQ4_NL) and the BF16 type in the SYCL backend. The quantization/dequantization copy kernels were moved from cpy.cpp to cpy.hpp to make them available to set_rows.cpp. This addresses part of the TODOs mentioned in the code.

* Use get_global_linear_id() instead

  ggml-ci

* Fix formatting

  ggml-ci

* Use const for ne11 and size_t variables in set_rows_sycl_q

  ggml-ci

* Increase block size for q kernel to 256

  ggml-ci

* Cleanup imports

* Add float.h to cpy.hpp
1 parent 00fa15f commit cd1fce6

File tree

4 files changed

+313
-218
lines changed

4 files changed

+313
-218
lines changed

ggml/src/ggml-sycl/cpy.cpp

Lines changed: 0 additions & 212 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,12 @@
11
#include "cpy.hpp"
22

33
#include <float.h>
4-
#include <string>
54

65
#include "dequantize.hpp"
76
#include "ggml-sycl/common.hpp"
87
#include "ggml-sycl/presets.hpp"
98
#include "ggml.h"
109

11-
static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) {
12-
if (x <= val[0]) {
13-
return 0;
14-
}
15-
if (x >= val[n - 1]) {
16-
return n - 1;
17-
}
18-
int ml = 0, mu = n - 1;
19-
while (mu - ml > 1) {
20-
int mav = (ml + mu) / 2;
21-
if (x < val[mav]) {
22-
mu = mav;
23-
} else {
24-
ml = mav;
25-
}
26-
}
27-
return x - val[mu - 1] < val[mu] - x ? mu - 1 : mu;
28-
}
2910

3011
static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
3112
const float * xi = (const float *) cxi;
@@ -97,28 +78,6 @@ static void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne
9778
cpy_1(cx + x_offset, cdst + dst_offset);
9879
}
9980

100-
static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
101-
const float * xi = (const float *) cxi;
102-
block_q8_0 * dsti = (block_q8_0 *) cdsti;
103-
104-
float amax = 0.0f; // absolute max
105-
106-
for (int j = 0; j < QK8_0; j++) {
107-
const float v = xi[j];
108-
amax = sycl::fmax(amax, sycl::fabs((float) v));
109-
}
110-
111-
const float d = amax / ((1 << 7) - 1);
112-
const float id = d ? 1.0f / d : 0.0f;
113-
114-
dsti->d = d;
115-
116-
for (int j = 0; j < QK8_0; ++j) {
117-
const float x0 = xi[j] * id;
118-
119-
dsti->qs[j] = sycl::round((float) x0);
120-
}
121-
}
12281

12382
/* quantized type same copy */
12483
template<typename T>
@@ -140,178 +99,7 @@ static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
14099
}
141100
}
142101

143-
static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
144-
const float * xi = (const float *) cxi;
145-
block_q4_0 * dsti = (block_q4_0 *) cdsti;
146-
147-
float amax = 0.0f;
148-
float vmax = 0.0f;
149-
150-
for (int j = 0; j < QK4_0; ++j) {
151-
const float v = xi[j];
152-
if (amax < sycl::fabs((float) v)) {
153-
amax = sycl::fabs((float) v);
154-
vmax = v;
155-
}
156-
}
157-
158-
const float d = vmax / -8;
159-
const float id = d ? 1.0f / d : 0.0f;
160-
161-
dsti->d = d;
162-
163-
for (int j = 0; j < QK4_0 / 2; ++j) {
164-
const float x0 = xi[0 + j] * id;
165-
const float x1 = xi[QK4_0 / 2 + j] * id;
166-
167-
const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 8.5f));
168-
const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 8.5f));
169-
170-
dsti->qs[j] = xi0;
171-
dsti->qs[j] |= xi1 << 4;
172-
}
173-
}
174-
175-
static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
176-
const float * xi = (const float *) cxi;
177-
block_q4_1 * dsti = (block_q4_1 *) cdsti;
178-
179-
float vmin = FLT_MAX;
180-
float vmax = -FLT_MAX;
181-
182-
for (int j = 0; j < QK4_1; ++j) {
183-
const float v = xi[j];
184-
185-
if (v < vmin) {
186-
vmin = v;
187-
}
188-
if (v > vmax) {
189-
vmax = v;
190-
}
191-
}
192-
193-
const float d = (vmax - vmin) / ((1 << 4) - 1);
194-
const float id = d ? 1.0f / d : 0.0f;
195-
196-
dsti->dm.x() = d;
197-
dsti->dm.y() = vmin;
198-
199-
for (int j = 0; j < QK4_1 / 2; ++j) {
200-
const float x0 = (xi[0 + j] - vmin) * id;
201-
const float x1 = (xi[QK4_1 / 2 + j] - vmin) * id;
202-
203-
const uint8_t xi0 = dpct::min(15, (int8_t) (x0 + 0.5f));
204-
const uint8_t xi1 = dpct::min(15, (int8_t) (x1 + 0.5f));
205102

206-
dsti->qs[j] = xi0;
207-
dsti->qs[j] |= xi1 << 4;
208-
}
209-
}
210-
211-
static void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
212-
const float * xi = (const float *) cxi;
213-
block_q5_0 * dsti = (block_q5_0 *) cdsti;
214-
215-
float amax = 0.0f;
216-
float vmax = 0.0f;
217-
218-
for (int j = 0; j < QK5_0; ++j) {
219-
const float v = xi[j];
220-
if (amax < sycl::fabs((float) v)) {
221-
amax = sycl::fabs((float) v);
222-
vmax = v;
223-
}
224-
}
225-
226-
const float d = vmax / -16;
227-
const float id = d ? 1.0f / d : 0.0f;
228-
229-
dsti->d = d;
230-
231-
uint32_t qh = 0;
232-
for (int j = 0; j < QK5_0 / 2; ++j) {
233-
const float x0 = xi[0 + j] * id;
234-
const float x1 = xi[QK5_0 / 2 + j] * id;
235-
236-
const uint8_t xi0 = dpct::min(31, (int8_t) (x0 + 16.5f));
237-
const uint8_t xi1 = dpct::min(31, (int8_t) (x1 + 16.5f));
238-
239-
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
240-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
241-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0 / 2);
242-
}
243-
memcpy(dsti->qh, &qh, sizeof(qh));
244-
}
245-
246-
static void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
247-
const float * xi = (const float *) cxi;
248-
block_q5_1 * dsti = (block_q5_1 *) cdsti;
249-
250-
float min = xi[0];
251-
float max = xi[0];
252-
253-
for (int j = 1; j < QK5_1; ++j) {
254-
const float v = xi[j];
255-
min = v < min ? v : min;
256-
max = v > max ? v : max;
257-
}
258-
259-
const float d = (max - min) / 31;
260-
const float id = d ? 1.0f / d : 0.0f;
261-
262-
dsti->dm.x() = d;
263-
dsti->dm.y() = min;
264-
265-
uint32_t qh = 0;
266-
for (int j = 0; j < QK5_1 / 2; ++j) {
267-
const float x0 = (xi[0 + j] - min) * id;
268-
const float x1 = (xi[QK5_1 / 2 + j] - min) * id;
269-
270-
const uint8_t xi0 = (uint8_t) (x0 + 0.5f);
271-
const uint8_t xi1 = (uint8_t) (x1 + 0.5f);
272-
273-
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
274-
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
275-
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1 / 2);
276-
}
277-
memcpy(dsti->qh, &qh, sizeof(qh));
278-
}
279-
280-
static void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
281-
const float * xi = (const float *) cxi;
282-
block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
283-
284-
float amax = 0.0f;
285-
float vmax = 0.0f;
286-
287-
for (int j = 0; j < QK4_NL; ++j) {
288-
const float v = xi[j];
289-
if (amax < sycl::fabs((float) v)) {
290-
amax = sycl::fabs((float) v);
291-
vmax = v;
292-
}
293-
}
294-
295-
float d = vmax / kvalues_iq4nl[0];
296-
const float id = d ? 1.0f / d : 0.0f;
297-
298-
float sumqx = 0, sumq2 = 0;
299-
for (int j = 0; j < QK4_NL / 2; ++j) {
300-
const float x0 = xi[0 + j] * id;
301-
const float x1 = xi[QK4_NL / 2 + j] * id;
302-
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
303-
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
304-
dsti->qs[j] = xi0 | (xi1 << 4);
305-
const float v0 = kvalues_iq4nl[xi0];
306-
const float v1 = kvalues_iq4nl[xi1];
307-
const float w0 = xi[0 + j] * xi[0 + j];
308-
const float w1 = xi[QK4_NL / 2 + j] * xi[QK4_NL / 2 + j];
309-
sumqx += w0 * v0 * xi[j] + w1 * v1 * xi[QK4_NL / 2 + j];
310-
sumq2 += w0 * v0 * v0 + w1 * v1 * v1;
311-
}
312-
313-
dsti->d = sumq2 > 0 ? sumqx / sumq2 : d;
314-
}
315103

316104
template <dequantize_kernel_t dequant, int qk> static void cpy_blck_q_f32(const char * cxi, char * cdsti) {
317105
float * cdstf = (float *) (cdsti);

0 commit comments

Comments
 (0)