Skip to content

Commit afc0e89

Browse files
authored
sycl: refactor quantization to q8_1 (#14815)
* sycl: quantization to q8_1 refactor * Refactored src1 copy logic in op_mul_mat
1 parent a5771c9 commit afc0e89

File tree

3 files changed

+184
-206
lines changed

3 files changed

+184
-206
lines changed

ggml/src/ggml-sycl/backend.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mmvq.hpp"
2929
#include "norm.hpp"
3030
#include "outprod.hpp"
31+
#include "quantize.hpp"
3132
#include "quants.hpp"
3233
#include "rope.hpp"
3334
#include "set_rows.hpp"

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 50 additions & 206 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "ggml-sycl/set_rows.hpp"
4545
#include "ggml-sycl/sycl_hw.hpp"
4646
#include "ggml-sycl/getrows.hpp"
47+
#include "ggml-sycl/quantize.hpp"
4748
#include "ggml.h"
4849

4950
static bool g_sycl_loaded = false;
@@ -1373,120 +1374,6 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
13731374

13741375

13751376

1376-
template<int QUANT_BLOCK_TILE>
1377-
static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
1378-
const sycl::nd_item<3> &item_ct1) {
1379-
const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
1380-
item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
1381-
1382-
if (ix >= kx_padded) {
1383-
return;
1384-
}
1385-
1386-
const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
1387-
item_ct1.get_local_id(1);
1388-
1389-
const int i_padded = iy*kx_padded + ix;
1390-
1391-
block_q8_1 * y = (block_q8_1 *) vy;
1392-
1393-
const int ib = i_padded / QK8_1; // block index
1394-
const int iqs = i_padded % QK8_1; // quant index
1395-
typedef sycl::vec<float, QUANT_BLOCK_TILE> TC;
1396-
typedef sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
1397-
TC zeros;
1398-
TQ qzeros;
1399-
#pragma unroll
1400-
for (int i = 0; i < QUANT_BLOCK_TILE; i++)
1401-
{
1402-
zeros[i] = 0.f;
1403-
qzeros[i] = 0;
1404-
}
1405-
const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros;
1406-
float sum = xi[0];
1407-
float amax = sycl::fabs(xi[0]);
1408-
#pragma unroll
1409-
for (int i = 1; i < QUANT_BLOCK_TILE; i++)
1410-
{
1411-
sum += xi[i];
1412-
amax = sycl::fmax(sycl::fabs(xi[i]), amax);
1413-
}
1414-
sum = warp_reduce_sum(sum, item_ct1);
1415-
amax = warp_reduce_max(amax, item_ct1);
1416-
1417-
const float d = amax / 127;
1418-
TQ q = qzeros;
1419-
if (amax != 0.0f)
1420-
{
1421-
#pragma unroll
1422-
for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
1423-
q[i] = sycl::round(xi[i] / d);
1424-
}
1425-
}
1426-
1427-
*(TQ *)&y[ib].qs[iqs] = q;
1428-
1429-
if (iqs > 0) {
1430-
return;
1431-
}
1432-
1433-
reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
1434-
reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
1435-
}
1436-
1437-
template <int ElementsPerWI>
1438-
static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor,
1439-
const int kx, const int kx_padded, const sycl::nd_item<1> & it) {
1440-
/*
1441-
Quantizes and reorders the resultant q8 tensor in a per row fashion
1442-
Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
1443-
*/
1444-
1445-
auto subgroup_id = it.get_group(0);
1446-
auto wi_id = it.get_local_id(0);
1447-
1448-
const int num_blocks_per_row = kx / QK8_1;
1449-
auto row = subgroup_id / num_blocks_per_row;
1450-
auto col = subgroup_id % num_blocks_per_row;
1451-
1452-
auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1);
1453-
auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
1454-
1455-
auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
1456-
auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2));
1457-
1458-
sycl::vec<float, ElementsPerWI> wi_f32_vals;
1459-
sycl::vec<int8_t, ElementsPerWI> quantized_values;
1460-
1461-
auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
1462-
wi_f32_vals = *reinterpret_cast<const sycl::vec<float, ElementsPerWI> *>(x + float_ptr_offset);
1463-
1464-
float sum = 0.0f;
1465-
float amax = 0.0f;
1466-
1467-
#pragma unroll(ElementsPerWI)
1468-
for (int i = 0; i < ElementsPerWI; i++) {
1469-
sum += wi_f32_vals[i];
1470-
amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i]));
1471-
quantized_values[i] = 0;
1472-
}
1473-
sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus<float>());
1474-
amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum<float>());
1475-
float d = amax == 0 ? 1 : amax / 127;
1476-
1477-
#pragma unroll(ElementsPerWI)
1478-
for (int i = 0; i < ElementsPerWI; i++) {
1479-
quantized_values[i] = sycl::round(wi_f32_vals[i] / d);
1480-
}
1481-
1482-
d = amax == 0 ? 0 : d;
1483-
1484-
*reinterpret_cast<sycl::vec<int8_t, ElementsPerWI> *>(quant_ptr) = quantized_values;
1485-
if (wi_id == 0) {
1486-
*ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum));
1487-
}
1488-
}
1489-
14901377
static void mul_mat_p021_f16_f32(
14911378
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
14921379
const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1770,32 +1657,6 @@ static void pool2d_nchw_kernel(
17701657
o_ptr[cur_oh * ow + cur_ow] = res;
17711658
}
17721659

1773-
static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded,
1774-
bool reorder_q8_tensor, queue_ptr stream) {
1775-
if (reorder_q8_tensor) {
1776-
auto local_range = std::size_t(WARP_SIZE);
1777-
auto num_quant_blocks = ky * (kx / QK8_1);
1778-
auto global_range = num_quant_blocks * local_range;
1779-
stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }),
1780-
[=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1781-
quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
1782-
});
1783-
} else {
1784-
const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
1785-
const sycl::range<3> num_blocks(1, ky, block_num_x);
1786-
int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1787-
static_assert(QK8_1 % WARP_SIZE == 0);
1788-
const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
1789-
{
1790-
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
1791-
1792-
stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size),
1793-
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
1794-
quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
1795-
});
1796-
}
1797-
}
1798-
}
17991660

18001661
static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
18011662
float *dst, const int ncols_x,
@@ -2372,10 +2233,10 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
23722233
peer_access_enabled = enable_peer_access;
23732234
}
23742235

2236+
template <template <int> typename quantize_f>
23752237
static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
23762238
const ggml_tensor *src1, ggml_tensor *dst,
2377-
ggml_sycl_op_mul_mat_t op,
2378-
const bool convert_src1_to_q8_1) try {
2239+
ggml_sycl_op_mul_mat_t op) try {
23792240

23802241
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
23812242

@@ -2470,6 +2331,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
24702331
}
24712332
}
24722333

2334+
constexpr bool quantize_enabled = !std::is_same_v<quantize_f<QK8_1 / WARP_SIZE>,
2335+
no_quantize_q8_1<QK8_1 / WARP_SIZE>>;
24732336
for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
24742337
if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
24752338
continue;
@@ -2495,20 +2358,19 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
24952358
dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
24962359
}
24972360

2498-
if (convert_src1_to_q8_1) {
2361+
if constexpr(quantize_enabled) {
24992362
dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
25002363

25012364
if (src1_on_device && src1_is_contiguous) {
2502-
bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder;
25032365
scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
25042366
/*num_src=*/2, " : converting src1 to Q8_1");
2505-
quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
2506-
/*
2507-
DPCT1010:90: SYCL uses exceptions to report errors and does not
2508-
use the error codes. The call was replaced with 0. You need to
2509-
rewrite this code.
2510-
*/
2511-
SYCL_CHECK(0);
2367+
try {
2368+
quantize_row_q8_1_sycl<quantize_f>(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
2369+
} catch (sycl::exception const &exc) {
2370+
std::cerr << "Quantize_row_q8_1_sycl error" << exc.what() << "Exception caught at file:" << __FILE__
2371+
<< ", line:" << __LINE__ << std::endl;
2372+
std::exit(1);
2373+
}
25122374
}
25132375
}
25142376

@@ -2524,11 +2386,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
25242386
// here an event is recorded that signals that the main device has finished calculating the input data
25252387
if (split && used_devices > 1) {
25262388
ggml_sycl_set_device(ctx.device);
2527-
/*
2528-
DPCT1024:91: The original code returned the error code that was further
2529-
consumed by the program logic. This original code was replaced with 0.
2530-
You may need to rewrite the program logic consuming the error code.
2531-
*/
25322389
SYCL_CHECK(CHECK_TRY_ERROR(
25332390
*src0_extra->events[ctx.device][0] =
25342391
ctx.stream()->ext_oneapi_submit_barrier()));
@@ -2552,11 +2409,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
25522409

25532410
// wait for main GPU data if necessary
25542411
if (split && (i != ctx.device || is != 0)) {
2555-
/*
2556-
DPCT1009:163: SYCL uses exceptions to report errors and does not
2557-
use the error codes. The original code was commented out and a
2558-
warning string was inserted. You need to rewrite this code.
2559-
*/
25602412
SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
25612413
{*src0_extra->events[ctx.device][0]})));
25622414
}
@@ -2582,39 +2434,42 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
25822434
// copy src0, src1 to device if necessary
25832435
if (src1_is_contiguous) {
25842436
if (i != ctx.device) {
2585-
if (convert_src1_to_q8_1) {
2437+
if constexpr (quantize_enabled) {
25862438
char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
2587-
SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
2588-
src1_ddq_i, src1_ddq_i_source,
2589-
src1_ncols * src1_padded_col_size * q8_1_ts /
2590-
q8_1_bs).wait()));
2439+
SYCL_CHECK(
2440+
CHECK_TRY_ERROR(stream
2441+
->memcpy(src1_ddq_i, src1_ddq_i_source,
2442+
src1_ncols * src1_padded_col_size * q8_1_ts / q8_1_bs)
2443+
.wait()));
25912444
} else {
2592-
25932445
float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
2594-
src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
2446+
src1_ddf_i_source += (i0 * ne11 + src1_col_0) * ne10;
25952447

2596-
SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
2597-
src1_ddf_i, src1_ddf_i_source,
2598-
src1_ncols * ne10 * sizeof(float))));
2448+
SYCL_CHECK(
2449+
CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream, src1_ddf_i, src1_ddf_i_source,
2450+
src1_ncols * ne10 * sizeof(float))));
25992451
}
26002452
}
2601-
} else if (src1_on_device && !src1_is_contiguous) {
2602-
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
2603-
src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
26042453
} else {
2605-
GGML_ABORT("fatal error");
2606-
}
2454+
if (src1_on_device) {
2455+
SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, src1_col_0,
2456+
src1_col_0 + src1_ncols, stream));
2457+
} else {
2458+
GGML_ABORT("src1 is non-contiguous and not on device");
2459+
}
26072460

2608-
if (convert_src1_to_q8_1 && !src1_is_contiguous) {
2609-
scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2610-
/*num_src=*/2, " : converting src1 to Q8_1");
2611-
quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream);
2612-
/*
2613-
DPCT1010:92: SYCL uses exceptions to report errors and does
2614-
not use the error codes. The call was replaced with 0. You
2615-
need to rewrite this code.
2616-
*/
2617-
SYCL_CHECK(0);
2461+
if constexpr (quantize_enabled) {
2462+
scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst,
2463+
/*num_src=*/2, " : converting src1 to Q8_1");
2464+
try {
2465+
quantize_row_q8_1_sycl<quantize_q8_1>(src1_ddf_i, src1_ddq_i, ne10, src1_ncols,
2466+
src1_padded_col_size, stream);
2467+
} catch (const sycl::exception & exc) {
2468+
std::cerr << "Quantize_row_q8_1_sycl error" << exc.what()
2469+
<< "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
2470+
std::exit(1);
2471+
}
2472+
}
26182473
}
26192474

26202475
if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
@@ -2626,12 +2481,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
26262481
// do the computation
26272482
SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
26282483
dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
2629-
/*
2630-
DPCT1010:93: SYCL uses exceptions to report errors and does not
2631-
use the error codes. The call was replaced with 0. You need to
2632-
rewrite this code.
2633-
*/
2634-
SYCL_CHECK(0);
26352484

26362485
// copy dst to host or other device if necessary
26372486
if (!dst_on_device) {
@@ -2662,12 +2511,6 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
26622511

26632512
// add event for the main device to wait on until other device is done
26642513
if (split && (i != ctx.device || is != 0)) {
2665-
/*
2666-
DPCT1024:94: The original code returned the error code that
2667-
was further consumed by the program logic. This original
2668-
code was replaced with 0. You may need to rewrite the
2669-
program logic consuming the error code.
2670-
*/
26712514
SYCL_CHECK(CHECK_TRY_ERROR(
26722515
*src0_extra->events[i][is] =
26732516
stream->ext_oneapi_submit_barrier()));
@@ -3351,19 +3194,20 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
33513194
// KQ + KQV multi-batch
33523195
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
33533196
} else if (use_dequantize_mul_mat_vec) {
3354-
constexpr bool convert_src1_to_q8_1 = false;
33553197
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV);
3356-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
3198+
ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec);
33573199
} else if (use_mul_mat_vec_q) {
3358-
constexpr bool convert_src1_to_q8_1 = true;
33593200
opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ);
3360-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
3201+
ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
3202+
if (extra && extra->optimized_feature.reorder) {
3203+
ggml_sycl_op_mul_mat<quantize_and_reorder_q8_1_soa>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
3204+
} else {
3205+
ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q);
3206+
}
33613207
} else if (use_mul_mat_q) {
3362-
constexpr bool convert_src1_to_q8_1 = true;
3363-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
3208+
ggml_sycl_op_mul_mat<quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q);
33643209
} else {
3365-
constexpr bool convert_src1_to_q8_1 = false;
3366-
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
3210+
ggml_sycl_op_mul_mat<no_quantize_q8_1>(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl);
33673211
}
33683212
}
33693213

0 commit comments

Comments
 (0)