Skip to content

Commit 33e5f11

Browse files
committed
opencl: add mul_mat_f32_f32_l4_lm and mul_mat_f16_f32_l4_lm
1 parent 8ad7b3e commit 33e5f11

File tree

4 files changed

+399
-12
lines changed

4 files changed

+399
-12
lines changed

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ set(GGML_OPENCL_KERNELS
8282
mul_mv_q4_0_f32_1d_16x_flat
8383
mul_mv_q6_k
8484
mul_mv_id_q4_0_f32_8x_flat
85+
mul_mm_f32_f32_l4_lm
86+
mul_mm_f16_f32_l4_lm
8587
mul
8688
norm
8789
relu

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 132 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#undef MAX
3434
#define MIN(a, b) ((a) < (b) ? (a) : (b))
3535
#define MAX(a, b) ((a) > (b) ? (a) : (b))
36+
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
3637

3738
#define UNUSED(x) (void)(x)
3839

@@ -396,6 +397,8 @@ struct ggml_backend_opencl_context {
396397
cl_program program_conv_2d_f16_f32;
397398
cl_program program_tsembd;
398399
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
400+
cl_program program_mul_mm_f32_f32_l4_lm;
401+
cl_program program_mul_mm_f16_f32_l4_lm;
399402

400403
cl_kernel kernel_add, kernel_add_row;
401404
cl_kernel kernel_mul, kernel_mul_row;
@@ -450,6 +453,8 @@ struct ggml_backend_opencl_context {
450453
cl_kernel kernel_conv_2d_f16_f32;
451454
cl_kernel kernel_timestep_embedding;
452455
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
456+
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
457+
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
453458

454459
std::vector<ProfilingInfo> profiling_info;
455460

@@ -1040,6 +1045,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
10401045
GGML_LOG_CONT(".");
10411046
}
10421047

1048+
// mul_mm_f32_f32_l4_lm
1049+
{
1050+
#ifdef GGML_OPENCL_EMBED_KERNELS
1051+
const std::string kernel_src {
1052+
#include "mul_mm_f32_f32_l4_lm.cl.h"
1053+
};
1054+
#else
1055+
const std::string kernel_src = read_file("mul_mm_f32_f32_l4_lm.cl");
1056+
#endif
1057+
backend_ctx->program_mul_mm_f32_f32_l4_lm =
1058+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1059+
1060+
CL_CHECK((backend_ctx->kernel_mul_mm_f32_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f32_f32_l4_lm, "kernel_mul_mm_f32_f32_l4_lm", &err), err));
1061+
GGML_LOG_CONT(".");
1062+
}
1063+
1064+
// mul_mm_f16_f32_l4_lm
1065+
{
1066+
#ifdef GGML_OPENCL_EMBED_KERNELS
1067+
const std::string kernel_src {
1068+
#include "mul_mm_f16_f32_l4_lm.cl.h"
1069+
};
1070+
#else
1071+
const std::string kernel_src = read_file("mul_mm_f16_f32_l4_lm.cl");
1072+
#endif
1073+
backend_ctx->program_mul_mm_f16_f32_l4_lm =
1074+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
1075+
1076+
CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_l4_lm, "kernel_mul_mm_f16_f32_l4_lm", &err), err));
1077+
GGML_LOG_CONT(".");
1078+
}
1079+
10431080
// mul
10441081
{
10451082
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -5297,18 +5334,6 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
52975334

52985335
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
52995336

5300-
if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
5301-
src0->ne[1] > 32 && // M > 32
5302-
src1->ne[1] > 32 && // N > 32
5303-
src0->ne[0] > 32 && // K > 32
5304-
src0->ne[2] == 1 && src0->ne[3] == 1 &&
5305-
src1->ne[2] == 1 && src1->ne[3] == 1 &&
5306-
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
5307-
backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
5308-
ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
5309-
return;
5310-
}
5311-
53125337
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
53135338
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
53145339
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@@ -5655,6 +5680,101 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
56555680
} // if (ne01 && ne1)
56565681
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
56575682

5683+
// GEMM using local memory
5684+
// Current BK = 16, so ne00 % 16 == 0
5685+
if (ggml_is_contiguous(src0) &&
5686+
ggml_is_contiguous(src1) &&
5687+
src1t == GGML_TYPE_F32 &&
5688+
ne00 % 16 == 0 &&
5689+
ne11 > 1) {
5690+
switch(src0t) {
5691+
case GGML_TYPE_F32: {
5692+
kernel = backend_ctx->kernel_mul_mm_f32_f32_l4_lm;
5693+
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
5694+
5695+
int batch_stride_a = ne00*ne01;
5696+
int batch_stride_b = ne10*ne11;
5697+
int batch_stride_d = ne0*ne1;
5698+
5699+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5700+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5701+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
5702+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
5703+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
5704+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
5705+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
5706+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
5707+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
5708+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
5709+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
5710+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
5711+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
5712+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
5713+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
5714+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
5715+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
5716+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
5717+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
5718+
5719+
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
5720+
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
5721+
size_t local_work_size[] = {(size_t)nth0, 1, 1};
5722+
5723+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
5724+
return;
5725+
}
5726+
case GGML_TYPE_F16: {
5727+
kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm;
5728+
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
5729+
5730+
int batch_stride_a = ne00*ne01;
5731+
int batch_stride_b = ne10*ne11;
5732+
int batch_stride_d = ne0*ne1;
5733+
5734+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
5735+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
5736+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
5737+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
5738+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
5739+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
5740+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
5741+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
5742+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
5743+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
5744+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
5745+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
5746+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
5747+
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
5748+
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
5749+
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
5750+
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
5751+
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
5752+
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
5753+
5754+
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
5755+
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
5756+
size_t local_work_size[] = {(size_t)nth0, 1, 1};
5757+
5758+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
5759+
return;
5760+
}
5761+
default:
5762+
break;
5763+
}
5764+
}
5765+
5766+
if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
5767+
src0->ne[1] > 32 && // M > 32
5768+
src1->ne[1] > 32 && // N > 32
5769+
src0->ne[0] > 32 && // K > 32
5770+
src0->ne[2] == 1 && src0->ne[3] == 1 &&
5771+
src1->ne[2] == 1 && src1->ne[3] == 1 &&
5772+
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
5773+
backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
5774+
ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
5775+
return;
5776+
}
5777+
56585778
if (!ggml_is_transposed(src0) &&
56595779
!ggml_is_transposed(src1) &&
56605780
src1t == GGML_TYPE_F32 &&
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Vector width used when staging tiles of A and B into local memory
// (A is read as half4, B as float4).
#define LOAD_VEC_A 4
#define LOAD_VEC_B 4

// Block-tile sizes: each work-group computes a BM x BN tile of the output,
// marching over the K dimension in steps of BK.
#define BM 64
#define BN 64
#define BK 16
// Per-thread register tile: each work-item accumulates a TM x TN sub-tile,
// so a work-group needs (BM*BN)/(TM*TN) = 128 work-items (see host code).
#define TM 4
#define TN 8

// GEMM with local-memory tiling: dst = src0 (f16) * src1 (f32), f32 output.
//
//   src0/offset0            matrix A, fp16, read as half4
//   src1/offset1            matrix B, fp32, read as float4
//   dst/offsetd             output D, fp32
//   ne00                    K (shared inner dimension)
//   ne01                    M (rows of A / D); used for output bounds check
//   ne02                    number of A matrices in dim 2 (for batch mapping)
//   ne11                    N (cols of B / D); used for output bounds check
//   ne12                    number of B matrices in dim 2 (for batch mapping)
//   stride_a/b/d            row strides, in elements
//   batch_stride_a/b/d      per-batch strides, in elements
//   r2, r3                  broadcast ratios mapping B batches onto A batches
//
// NOTE(review): loads assume ne00 % BK == 0 and stride divisible by the load
// vector width — the host gates on ne00 % 16 == 0; confirm for other callers.
kernel void kernel_mul_mm_f16_f32_l4_lm(
        global half4 * src0,
        ulong offset0,
        global float4 * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,

        int ne00,
        int ne01,
        int ne02,
        int ne11,
        int ne12,

        int stride_a,
        int stride_b,
        int stride_d,

        int batch_stride_a,
        int batch_stride_b,
        int batch_stride_d,

        int r2,
        int r3
) {
    // Apply byte offsets before any indexing.
    src0 = (global half4*)((global char*)src0 + offset0);
    src1 = (global float4*)((global char*)src1 + offset1);
    dst = (global float*)((global char*)dst + offsetd);

    // Local-memory staging tiles, stored K-major (index [k*B? + row/col])
    // so the inner product loop below reads them contiguously.
    local half buf_a[BM * BK];
    local float buf_b[BN * BK];

    const int batch_idx = get_global_id(2);

    // Decompose the flat batch index into (i13, i12) over B's batch dims,
    // then map onto A's batches via the broadcast ratios r3, r2.
    const int i13 = batch_idx / ne12;
    const int i12 = batch_idx % ne12;

    const int i03 = i13 / r3;
    const int i02 = i12 / r2;

    const int batch_idx_a = i03 * ne02 + i02;

    // Work-group coordinates: ir selects the BM-row block, ic the BN-col block.
    const int ir = get_group_id(0);
    const int ic = get_group_id(1);

    // th_r/th_c locate this work-item's TM x TN register tile inside the block.
    const int tid = get_local_id(0);
    const int th_r = tid % (BM / TM);
    const int th_c = tid / (BM / TM);

    // Cooperative-load coordinates: each work-item loads one vector per row,
    // advancing by loadstride rows until the whole tile is staged.
    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);

    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;

    // Base positions into src0/src1, in vector (half4/float4) units.
    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
    int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;

    float sums[TM * TN];
    half cache_a[TM];
    float cache_b[TN];

    for (int i = 0; i < TM * TN; i++) {
        sums[i] = 0.0f;
    }

    for (int block = 0; block < ne00; block += BK) {
        // Stage a BM x BK tile of A into local memory, transposing to K-major.
        for (int l = 0; l < BM; l += loadstride_a) {
            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
        }

        // Stage a BN x BK tile of B, likewise K-major.
        for (int l = 0; l < BN; l += loadstride_b) {
            const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
        }

        // All loads must complete before any work-item reads the tiles.
        barrier(CLK_LOCAL_MEM_FENCE);

        pos_a += BK / LOAD_VEC_A;
        pos_b += BK / LOAD_VEC_B;

        // Inner product over the BK depth slice: copy this work-item's
        // operands to registers, then accumulate in f32 (A is widened
        // from half via convert_float).
        for (int i = 0; i < BK; i++) {
            for (int j = 0; j < TM; j++) {
                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
            }
            for (int j = 0; j < TN; j++) {
                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
            }

            for (int cc = 0; cc < TN; cc++) {
                for (int cr = 0; cr < TM; cr++) {
                    const int sums_idx = cc*TM + cr;
                    sums[sums_idx] = mad(convert_float(cache_a[cr]), cache_b[cc], sums[sums_idx]);
                }
            }
        }
        // Tiles must not be overwritten while any work-item still reads them.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Global output coordinates of this work-item's register tile.
    const int dr = ir * BM + th_r * TM;
    const int dc = ic * BN + th_c * TN;

    const int offsets = batch_idx * batch_stride_d;

    // Bounds-checked writeback: edge blocks may overhang when M or N is not
    // a multiple of BM/BN.
    for (int cc = 0; cc < TN; cc++) {
        for (int cr = 0; cr < TM; cr++) {
            if (dr + cr < ne01 && dc + cc < ne11) {
                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
            }
        }
    }
}

0 commit comments

Comments
 (0)