@@ -1157,61 +1157,8 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
1157
1157
}
1158
1158
#else
1159
1159
1160
- float sumf[8];
1161
- float sum_minf[8];
1162
- int sumi1,sumi2,sumi3,sumi4;
1163
- int sumi;
1164
-
1165
- const block_q8_K * a_ptr = (const block_q8_K *)vy;
1166
- for(int x = 0; x < nc / ncols_interleaved; x++) {
1167
- const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
1168
- for (int j = 0; j < ncols_interleaved; j++) {
1169
- sumf[j] = 0.0;
1170
- sum_minf[j] = 0.0;
1171
- }
1172
- for (int l = 0; l < nb; l++) {
1173
- for (int k = 0; k < (qk / (4 * blocklen)); k++) {
1174
- uint8_t *scales_0 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 ;
1175
- uint8_t *scales_1 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 16;
1176
- uint8_t *scales_2 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 32;
1177
- uint8_t *scales_3 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 48;
1178
- for (int j = 0; j < ncols_interleaved; j++) {
1179
- sumi1 = 0;
1180
- sumi2 = 0;
1181
- sumi3 = 0;
1182
- sumi4 = 0;
1183
- sumi = 0;
1184
- int offset = ((k / 2) % 2) + j * 2;
1185
- for (int i = 0; i < blocklen; ++i){
1186
- const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
1187
- const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
1188
- const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
1189
- const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
1190
- sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
1191
- sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
1192
- sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
1193
- sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);
1194
-
1195
- sumi1 = sumi1 * (scales_0[offset] & 0xF);
1196
- sumi2 = sumi2 * (scales_1[offset] & 0xF);
1197
- sumi3 = sumi3 * (scales_2[offset] & 0xF);
1198
- sumi4 = sumi4 * (scales_3[offset] & 0xF);
1199
- sumi += sumi1 + sumi2 + sumi3 + sumi4;
1200
- }
1201
- sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
1202
- }
1203
- }
1204
- for(int sb = 0; sb < 8; sb++) {
1205
- uint8_t *mins = (uint8_t*) b_ptr[l].scales + sb * 16;
1206
- for(int j = 0; j < ncols_interleaved; j++){
1207
- sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
1208
- }
1209
- }
1210
- }
1211
- for (int j = 0; j < ncols_interleaved; j++) {
1212
- s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
1213
- }
1214
- }
1160
+ ggml_gemv_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1161
+
1215
1162
#endif
1216
1163
}
1217
1164
@@ -6294,74 +6241,8 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
6294
6241
}
6295
6242
#else
6296
6243
6297
- float sumf[4][8];
6298
- float sum_minf[4][8];
6299
- int sumi1, sumi2, sumi3, sumi4;
6300
- int sumi;
6301
-
6302
- for (int y = 0; y < nr / 4; y++) {
6303
- const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
6304
- for (int x = 0; x < nc / ncols_interleaved; x++) {
6305
- const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
6306
- for (int m = 0; m < 4; m++) {
6307
- for (int j = 0; j < ncols_interleaved; j++) {
6308
- sumf[m][j] = 0.0;
6309
- sum_minf[m][j] = 0.0;
6310
- }
6311
- }
6312
- for (int l = 0; l < nb; l++) {
6313
- for (int k = 0; k < (qk / (4 * blocklen)); k++) {
6314
-
6315
- uint8_t *scales_0 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 ;
6316
- uint8_t *scales_1 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 16;
6317
- uint8_t *scales_2 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 32;
6318
- uint8_t *scales_3 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 48;
6319
- for (int m = 0; m < 4; m++) {
6320
- for (int j = 0; j < ncols_interleaved; j++) {
6321
- sumi1 = 0;
6322
- sumi2 = 0;
6323
- sumi3 = 0;
6324
- sumi4 = 0;
6325
- sumi = 0;
6326
- int offset = ((k / 2) % 2) + j * 2;
6327
- for (int i = 0; i < blocklen; ++i){
6328
- const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x03);
6329
- const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 0x03);
6330
- const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 0x03);
6331
- const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 0x03);
6332
- sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
6333
- sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
6334
- sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
6335
- sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
6336
- sumi1 = sumi1 * (scales_0[offset] & 0xF);
6337
- sumi2 = sumi2 * (scales_1[offset] & 0xF);
6338
- sumi3 = sumi3 * (scales_2[offset] & 0xF);
6339
- sumi4 = sumi4 * (scales_3[offset] & 0xF);
6340
- sumi += sumi1 + sumi2 + sumi3 + sumi4;
6341
- }
6342
- sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
6343
- }
6344
- }
6345
- }
6346
- for(int sb = 0; sb < 8; sb++) {
6347
- uint8_t *mins = (uint8_t*) b_ptr[l].scales + sb * 16;
6348
- for(int m = 0; m < 4; m++) {
6349
- const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
6350
- for(int j = 0; j < ncols_interleaved; j++) {
6351
- int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]);
6352
- sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
6353
- }
6354
- }
6355
- }
6356
- }
6244
+ ggml_gemm_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
6357
6245
6358
- for (int m = 0; m < 4; m++) {
6359
- for (int j = 0; j < ncols_interleaved; j++) {
6360
- s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
6361
- }
6362
- }
6363
- }
6364
- }
6365
6246
6366
6247
#endif
6367
6248
}
0 commit comments