Skip to content

Commit cac45d3

Browse files
committed
vulkan: use q8_1_x4 blocks in mul_mmq shader
1 parent cf38145 commit cac45d3

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5576,7 +5576,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
55765576
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
55775577

55785578
if (quantize_y) {
5579-
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false);
5579+
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
55805580
}
55815581

55825582
if (dryrun) {
@@ -5593,7 +5593,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
55935593
ctx->prealloc_size_x = x_sz_upd;
55945594
}
55955595
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
5596-
ctx->prealloc_size_y = y_sz_upd;
5596+
ctx->prealloc_size_y = CEIL_DIV(y_sz_upd, 128) * 128;
55975597
}
55985598
if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
55995599
ctx->prealloc_size_split_k = split_k_size;
@@ -5665,7 +5665,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
56655665
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
56665666
}
56675667
if (quantize_y) {
5668-
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
5668+
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
56695669
}
56705670

56715671
uint32_t stride_batch_x = ne00*ne01;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
2828
#if defined(A_TYPE_PACKED32)
2929
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
3030
#endif
31-
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
31+
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
3232
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
3333

3434
#ifdef MUL_MAT_ID
@@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
9898
#endif
9999

100100
#define LOAD_VEC_A (4 * QUANT_R)
101-
#define LOAD_VEC_B 4
101+
#define LOAD_VEC_B 16
102102

103103
#ifdef MUL_MAT_ID
104104
shared u16vec2 row_ids[4096];
@@ -270,15 +270,22 @@ void main() {
270270
const uint iqs = idx & 0x7;
271271
#else
272272
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
273+
const uint ib_outer = ib / 4;
274+
const uint ib_inner = ib % 4;
275+
273276
const uint iqs = loadr_b;
274277
#endif
275278

276279
const uint buf_ib = loadc_b + l;
277280

278281
if (iqs == 0) {
279-
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
282+
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
280283
}
281-
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
284+
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
285+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
286+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
287+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
288+
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
282289
}
283290

284291
barrier();

0 commit comments

Comments
 (0)