
 #ifdef USE_SUBGROUPS
 #extension GL_KHR_shader_subgroup_basic : require
-#extension GL_KHR_shader_subgroup_clustered : require
+#extension GL_KHR_shader_subgroup_arithmetic : require

 #define INVOCATION_ID gl_SubgroupInvocationID.x
 #else
 #define INVOCATION_ID gl_LocalInvocationID.x
 #endif

 #define MMQ
-#define B_TYPE block_q8_1_x4_packed128
+#define B_TYPE block_q8_1_x4

 #include "mul_mat_vec_base.comp"

 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

-#define K_PER_ITER 32
-
-const uint GROUP_SIZE = 8;
-const uint GROUPS_PER_WARP = (BLOCK_SIZE / GROUP_SIZE);
+#define K_PER_ITER 8

 #include "mul_mmq_funcs.comp"

-uint a_offset, b_offset, d_offset, y_offset;
+uint a_offset, b_offset, d_offset;

-#ifdef USE_SUBGROUPS
-void reduce_result_grouped(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid_in_group) {
-    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            temp[j][n] = subgroupClusteredAdd(temp[j][n], GROUP_SIZE);
-        }
-    }
-
-    if (tid_in_group == 0) {
-        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
-            }
-        }
-    }
-}
-#else
-void reduce_result_grouped(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid_in_group) {
-    const uint tid = INVOCATION_ID;
-    // sum up partial sums and write back result
-    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            tmpsh[j][n][tid] = temp[j][n];
-        }
-    }
-    barrier();
-    if (tid_in_group < 4) {
-        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[j][n][tid] += tmpsh[j][n][tid + 4];
-            }
-        }
-    }
-    barrier();
-    if (tid_in_group < 2) {
-        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                tmpsh[j][n][tid] += tmpsh[j][n][tid + 2];
-            }
-        }
-    }
-    barrier();
-    if (tid_in_group == 0) {
-        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][tid] + tmpsh[j][n][tid + 1]);
-            }
-        }
-    }
-}
-#endif
-
-ivec4 cache_b_qs[2];
+int32_t cache_b_qs[2];
 vec2 cache_b_ds;

-void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid_in_group, const uint i) {
+void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-        const uint col = i*GROUP_SIZE + K_PER_ITER*tid_in_group;
+        const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;

         // Preload data_b block
         const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
+        const uint b_qs_idx = tid % 4;
         const uint b_block_idx_outer = b_block_idx / 4;
         const uint b_block_idx_inner = b_block_idx % 4;
         cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
-        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 2];
-        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 2 + 1];
+
+#if QUANT_R == 2
+        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
+        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
+#else
+        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
+        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
+#endif

         uint ibi = first_row*p.ncols;
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
@@ -102,71 +54,36 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const

             int32_t q_sum = 0;
 #if QUANT_R == 2
-            i32vec2 data_a_qs = repack(a_block_idx, 0);
-            q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[0].x);
-            q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[1].x);
-            data_a_qs = repack(a_block_idx, 1);
-            q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[0].y);
-            q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[1].y);
-            data_a_qs = repack(a_block_idx, 2);
+            const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
             q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[0].z);
+                                     cache_b_qs[0]);
             q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[1].z);
-            data_a_qs = repack(a_block_idx, 3);
-            q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[0].w);
-            q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[1].w);
+                                     cache_b_qs[1]);
 #else
-            int32_t data_a_qs = repack(a_block_idx, 0);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[0].x);
-            data_a_qs = repack(a_block_idx, 1);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[0].y);
-            data_a_qs = repack(a_block_idx, 2);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[0].z);
-            data_a_qs = repack(a_block_idx, 3);
+            int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
             q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[0].w);
-            data_a_qs = repack(a_block_idx, 4);
+                                     cache_b_qs[0]);
+            data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
             q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[1].x);
-            data_a_qs = repack(a_block_idx, 5);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[1].y);
-            data_a_qs = repack(a_block_idx, 6);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[1].z);
-            data_a_qs = repack(a_block_idx, 7);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[1].w);
+                                     cache_b_qs[1]);
 #endif

 #if QUANT_AUXF == 1
-            temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds);
+            temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
 #else
-            temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds);
+            temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
 #endif
         }
     }
 }

 void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
-    const uint tid_in_group = INVOCATION_ID % GROUP_SIZE;
+    const uint tid = INVOCATION_ID;

     get_offsets(a_offset, b_offset, d_offset);
     a_offset /= QUANT_K;
     b_offset /= QUANT_K_Q8_1;

-    y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
-
     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
@@ -175,8 +92,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         }
     }

-    uint num_iters = p.ncols / (K_PER_ITER * GROUP_SIZE);
-    if (num_iters * K_PER_ITER * GROUP_SIZE + K_PER_ITER*tid_in_group < p.ncols) {
+    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
+    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
         num_iters++;
     }
     int unroll_count = 4;
@@ -186,7 +103,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     while (i < unrolled_iters) {
         // Manually partially unroll the loop
         [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
-            iter(temp, first_row, num_rows, tid_in_group, i*K_PER_ITER);
+            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
             i++;
         }
     }
@@ -205,22 +122,20 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     while (i < unrolled_iters) {
         // Manually partially unroll the loop
         [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
-            iter(temp, first_row, num_rows, tid_in_group, i*K_PER_ITER);
+            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
             i++;
         }
     }
     while (i < num_iters) {
-        iter(temp, first_row, num_rows, tid_in_group, i*K_PER_ITER);
+        iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
         i++;
     }

-    reduce_result_grouped(temp, d_offset, first_row, num_rows, tid_in_group);
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
 }

 void main() {
-    const uint group_id = INVOCATION_ID / GROUP_SIZE;
-    // 8 threads work together on a NUM_ROWS * NUM_COLS block/slice
-    const uint first_row = NUM_ROWS * (GROUPS_PER_WARP * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z) + group_id);
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

     // do NUM_ROWS at a time, unless there aren't enough remaining rows
     if (first_row + NUM_ROWS <= p.stride_d) {