Skip to content

Commit c19ec8f

Browse files
committed
vulkan: use q8_1_x4 blocks in mul_mmq shader
1 parent e83d158 commit c19ec8f

File tree

2 files changed: 14 additions and 7 deletions.

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5488,7 +5488,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
54885488
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
54895489

54905490
if (quantize_y) {
5491-
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
5491+
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
54925492
}
54935493

54945494
if (dryrun) {
@@ -5505,7 +5505,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
55055505
ctx->prealloc_size_x = x_sz_upd;
55065506
}
55075507
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
5508-
ctx->prealloc_size_y = y_sz_upd;
5508+
ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
55095509
}
55105510
if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
55115511
ctx->prealloc_size_split_k = split_k_size;
@@ -5577,7 +5577,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
55775577
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
55785578
}
55795579
if (quantize_y) {
5580-
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
5580+
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
55815581
}
55825582

55835583
uint32_t stride_batch_x = ne00*ne01;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
2828
#if defined(A_TYPE_PACKED32)
2929
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
3030
#endif
31-
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
31+
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
3232
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
3333

3434
#ifdef MUL_MAT_ID
@@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
9898
#endif
9999

100100
#define LOAD_VEC_A (4 * QUANT_R)
101-
#define LOAD_VEC_B 4
101+
#define LOAD_VEC_B 16
102102

103103
#ifdef MUL_MAT_ID
104104
shared u16vec2 row_ids[4096];
@@ -270,15 +270,22 @@ void main() {
270270
const uint iqs = idx & 0x7;
271271
#else
272272
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
273+
const uint ib_outer = ib / 4;
274+
const uint ib_inner = ib % 4;
275+
273276
const uint iqs = loadr_b;
274277
#endif
275278

276279
const uint buf_ib = loadc_b + l;
277280

278281
if (iqs == 0) {
279-
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
282+
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
280283
}
281-
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
284+
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
285+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
286+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
287+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
288+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
282289
}
283290

284291
barrier();

0 commit comments

Comments
 (0)