Skip to content

Commit 5666ed9

Browse files
committed
opencl: add mul_mat_f32_f32_l4_lm and mul_mat_f16_f32_l4_lm
1 parent 38d3af1 commit 5666ed9

File tree

4 files changed

+399
-12
lines changed

4 files changed

+399
-12
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ set(GGML_OPENCL_KERNELS
8282
mul_mv_q4_0_f32_1d_16x_flat
8383
mul_mv_q6_k
8484
mul_mv_id_q4_0_f32_8x_flat
85+
mul_mm_f32_f32_l4_lm
86+
mul_mm_f16_f32_l4_lm
8587
mul
8688
norm
8789
relu

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 132 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#undef MAX
3434
#define MIN(a, b) ((a) < (b) ? (a) : (b))
3535
#define MAX(a, b) ((a) > (b) ? (a) : (b))
36+
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
3637

3738
#define UNUSED(x) (void)(x)
3839

@@ -395,6 +396,8 @@ struct ggml_backend_opencl_context {
395396
cl_program program_conv_2d_f16_f32;
396397
cl_program program_tsembd;
397398
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
399+
cl_program program_mul_mm_f32_f32_l4_lm;
400+
cl_program program_mul_mm_f16_f32_l4_lm;
398401

399402
cl_kernel kernel_add, kernel_add_row;
400403
cl_kernel kernel_mul, kernel_mul_row;
@@ -449,6 +452,8 @@ struct ggml_backend_opencl_context {
449452
cl_kernel kernel_conv_2d_f16_f32;
450453
cl_kernel kernel_timestep_embedding;
451454
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
455+
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
456+
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
452457

453458
std::vector<ProfilingInfo> profiling_info;
454459

@@ -1039,6 +1044,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
10391044
GGML_LOG_CONT(".");
10401045
}
10411046

1047+
// mul_mm_f32_f32_l4_lm
1048+
{
1049+
#ifdef GGML_OPENCL_EMBED_KERNELS
1050+
const std::string kernel_src {
1051+
#include "mul_mm_f32_f32_l4_lm.cl.h"
1052+
};
1053+
#else
1054+
const std::string kernel_src = read_file("mul_mm_f32_f32_l4_lm.cl");
1055+
#endif
1056+
backend_ctx->program_mul_mm_f32_f32_l4_lm =
1057+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1058+
1059+
CL_CHECK((backend_ctx->kernel_mul_mm_f32_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f32_f32_l4_lm, "kernel_mul_mm_f32_f32_l4_lm", &err), err));
1060+
GGML_LOG_CONT(".");
1061+
}
1062+
1063+
// mul_mm_f16_f32_l4_lm
1064+
{
1065+
#ifdef GGML_OPENCL_EMBED_KERNELS
1066+
const std::string kernel_src {
1067+
#include "mul_mm_f16_f32_l4_lm.cl.h"
1068+
};
1069+
#else
1070+
const std::string kernel_src = read_file("mul_mm_f16_f32_l4_lm.cl");
1071+
#endif
1072+
backend_ctx->program_mul_mm_f16_f32_l4_lm =
1073+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1074+
1075+
CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_l4_lm, "kernel_mul_mm_f16_f32_l4_lm", &err), err));
1076+
GGML_LOG_CONT(".");
1077+
}
1078+
10421079
// mul
10431080
{
10441081
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -5139,18 +5176,6 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
51395176

51405177
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
51415178

5142-
if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
5143-
src0->ne[1] > 32 && // M > 32
5144-
src1->ne[1] > 32 && // N > 32
5145-
src0->ne[0] > 32 && // K > 32
5146-
src0->ne[2] == 1 && src0->ne[3] == 1 &&
5147-
src1->ne[2] == 1 && src1->ne[3] == 1 &&
5148-
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
5149-
backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
5150-
ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
5151-
return;
5152-
}
5153-
51545179
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
51555180
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
51565181
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@@ -5497,6 +5522,101 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
54975522
} // if (ne01 && ne1)
54985523
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
54995524

5525+
// GEMM using local memory
5526+
// Current BK = 16, so ne00 % 16 == 0
5527+
if (ggml_is_contiguous(src0) &&
5528+
ggml_is_contiguous(src1) &&
5529+
src1t == GGML_TYPE_F32 &&
5530+
ne00 % 16 == 0 &&
5531+
ne11 > 1) {
5532+
switch(src0t) {
5533+
case GGML_TYPE_F32: {
5534+
kernel = backend_ctx->kernel_mul_mm_f32_f32_l4_lm;
5535+
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
5536+
5537+
int batch_stride_a = ne00*ne01;
5538+
int batch_stride_b = ne10*ne11;
5539+
int batch_stride_d = ne0*ne1;
5540+
5541+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5542+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5543+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
5544+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
5545+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
5546+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
5547+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
5548+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
5549+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
5550+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
5551+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
5552+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
5553+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
5554+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
5555+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
5556+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
5557+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
5558+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
5559+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
5560+
5561+
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
5562+
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
5563+
size_t local_work_size[] = {(size_t)nth0, 1, 1};
5564+
5565+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
5566+
return;
5567+
}
5568+
case GGML_TYPE_F16: {
5569+
kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm;
5570+
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
5571+
5572+
int batch_stride_a = ne00*ne01;
5573+
int batch_stride_b = ne10*ne11;
5574+
int batch_stride_d = ne0*ne1;
5575+
5576+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5577+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5578+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
5579+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
5580+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
5581+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
5582+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
5583+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
5584+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
5585+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
5586+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
5587+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
5588+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
5589+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
5590+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
5591+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
5592+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
5593+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
5594+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
5595+
5596+
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
5597+
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
5598+
size_t local_work_size[] = {(size_t)nth0, 1, 1};
5599+
5600+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
5601+
return;
5602+
}
5603+
default:
5604+
break;
5605+
}
5606+
}
5607+
5608+
if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
5609+
src0->ne[1] > 32 && // M > 32
5610+
src1->ne[1] > 32 && // N > 32
5611+
src0->ne[0] > 32 && // K > 32
5612+
src0->ne[2] == 1 && src0->ne[3] == 1 &&
5613+
src1->ne[2] == 1 && src1->ne[3] == 1 &&
5614+
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
5615+
backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
5616+
ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
5617+
return;
5618+
}
5619+
55005620
if (!ggml_is_transposed(src0) &&
55015621
!ggml_is_transposed(src1) &&
55025622
src1t == GGML_TYPE_F32 &&
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Vector width (in elements) used for the global-memory tile loads:
// A is read as half4, B as float4 — 4 elements per load on both sides.
#define LOAD_VEC_A 4
#define LOAD_VEC_B 4

// Block-tile sizes: each work-group produces a BM x BN tile of the output,
// stepping along the K dimension in chunks of BK.
#define BM 64
#define BN 64
#define BK 16
// Thread-tile sizes: each work-item accumulates a TM x TN sub-tile, so the
// required work-group size is (BM*BN)/(TM*TN) = 128 work-items.
#define TM 4
#define TN 8

// Tiled matrix multiplication using local memory ("l4_lm"):
//   D = A * B, with A in f16 (src0), B in f32 (src1), D in f32 (dst).
// Each work-group cooperatively stages a BM x BK tile of A and a BN x BK tile
// of B into local memory, then each work-item accumulates a TM x TN register
// sub-tile with mad() before writing its results to dst.
//
// Parameters:
//   src0/offset0, src1/offset1, dst/offsetd — device buffers plus byte offsets
//     (offsets are applied up front via char* arithmetic).
//   ne00 — K (shared inner dimension; caller guarantees ne00 % BK == 0).
//   ne01 — number of rows of A (M); bounds the row index of stores.
//   ne02 — batch extent of A; used to compute the A batch index.
//   ne11 — number of columns of D (N); bounds the column index of stores.
//   ne12 — batch extent of B; batch_idx is decomposed as i13*ne12 + i12.
//   stride_a/stride_b/stride_d — leading-dimension strides (in elements) of
//     A, B and D respectively.
//   batch_stride_a/b/d — per-batch strides (in elements).
//   r2, r3 — divisors applied to B's batch indices to obtain A's batch
//     indices (i02 = i12 / r2, i03 = i13 / r3); presumably the broadcast
//     ratios ne12/ne02 and ne13/ne03 — confirm against the host-side caller.
//
// NOTE(review): only the final stores are bounds-checked (dr+cr < ne01,
// dc+cc < ne11); the tile loads are not, so partial edge tiles appear to read
// past the logical matrix extents — verify the host guarantees the backing
// allocations tolerate this.
kernel void kernel_mul_mm_f16_f32_l4_lm(
        global half4 * src0,
        ulong offset0,
        global float4 * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,

        int ne00,
        int ne01,
        int ne02,
        int ne11,
        int ne12,

        int stride_a,
        int stride_b,
        int stride_d,

        int batch_stride_a,
        int batch_stride_b,
        int batch_stride_d,

        int r2,
        int r3
) {
    // Apply the byte offsets before any indexing.
    src0 = (global half4*)((global char*)src0 + offset0);
    src1 = (global float4*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    // Local-memory staging tiles. Both are stored K-major (index = k * BM + m
    // resp. k * BN + n) so the inner product loop below reads consecutively.
    local half buf_a[BM * BK];
    local float buf_b[BN * BK];

    // Work-group's position: dim 2 enumerates batches of B/D.
    const int batch_idx = get_global_id(2);

    // Decompose the flat batch index into B's two outer dims.
    const int i13 = batch_idx / ne12;
    const int i12 = batch_idx % ne12;

    // Map B's batch indices onto A's (r2/r3 divide them down — see header).
    const int i03 = i13 / r3;
    const int i02 = i12 / r2;

    const int batch_idx_a = i03 * ne02 + i02;

    // ir/ic: which BM x BN output tile this work-group owns.
    const int ir = get_group_id(0);
    const int ic = get_group_id(1);

    // th_r in [0, BM/TM): row of this work-item's TM x TN sub-tile.
    // th_c in [0, BN/TN): column of the sub-tile.
    const int tid = get_local_id(0);
    const int th_r = tid % (BM / TM);
    const int th_c = tid / (BM / TM);

    // Cooperative-load coordinates: loadr_* indexes the vectorized K chunk,
    // loadc_* the row (A) / column (B) within the tile.
    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);

    // How many tile rows/cols the whole work-group covers per load iteration.
    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;

    // Start positions in the vectorized global buffers (element index / 4).
    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
    int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;

    // Per-work-item accumulator and register caches for one K slice.
    float sums[TM * TN];
    half cache_a[TM];
    float cache_b[TN];

    for (int i = 0; i < TM * TN; i++) {
        sums[i] = 0.0f;
    }

    // March along K in BK-sized steps.
    for (int block = 0; block < ne00; block += BK) {
        // Stage a BM x BK tile of A, transposing into K-major local layout:
        // the half4 at idx scatters into 4 consecutive k-rows of buf_a.
        for (int l = 0; l < BM; l += loadstride_a) {
            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
        }

        // Stage a BN x BK tile of B the same way.
        for (int l = 0; l < BN; l += loadstride_b) {
            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
        }

        // All loads must land in local memory before anyone consumes them.
        barrier(CLK_LOCAL_MEM_FENCE);

        // Advance the global positions to the next BK slice.
        pos_a += BK / LOAD_VEC_A;
        pos_b += BK / LOAD_VEC_B;

        // Rank-1 updates: for each k in the slice, pull this work-item's TM
        // A-values and TN B-values into registers, then fuse-multiply-add
        // into the TM x TN accumulator.
        for (int i = 0; i < BK; i++) {
            for (int j = 0; j < TM; j++) {
                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
            }
            for (int j = 0; j < TN; j++) {
                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
            }

            for (int cc = 0; cc < TN; cc++) {
                for (int cr = 0; cr < TM; cr++) {
                    const int sums_idx = cc*TM + cr;
                    sums[sums_idx] = mad(convert_float(cache_a[cr]), cache_b[cc], sums[sums_idx]);
                }
            }
        }
        // Don't overwrite the tiles until every work-item is done reading.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Global output coordinates of this work-item's sub-tile.
    const int dr = ir * BM + th_r * TM;
    const int dc = ic * BN + th_c * TN;

    const int offsets = batch_idx * batch_stride_d;

    // Bounds-checked store of the TM x TN results (column-major within dst:
    // column index scales by stride_d).
    for (int cc = 0; cc < TN; cc++) {
        for (int cr = 0; cr < TM; cr++) {
            if (dr + cr < ne01 && dc + cc < ne11) {
                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
            }
        }
    }
}

0 commit comments

Comments
 (0)