Skip to content

Commit a1053fb

Browse files
committed
Address q2k comments
1 parent d6ee6da commit a1053fb

File tree

1 file changed

+3
-122
lines changed

1 file changed

+3
-122
lines changed

ggml/src/ggml-cpu/arch/x86/repack.cpp

Lines changed: 3 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,61 +1157,8 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
11571157
}
11581158
#else
11591159

1160-
float sumf[8];
1161-
float sum_minf[8];
1162-
int sumi1,sumi2,sumi3,sumi4;
1163-
int sumi;
1164-
1165-
const block_q8_K * a_ptr = (const block_q8_K *)vy;
1166-
for(int x = 0; x < nc / ncols_interleaved; x++) {
1167-
const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
1168-
for (int j = 0; j < ncols_interleaved; j++) {
1169-
sumf[j] = 0.0;
1170-
sum_minf[j] = 0.0;
1171-
}
1172-
for (int l = 0; l < nb; l++) {
1173-
for (int k = 0; k < (qk / (4 * blocklen)); k++) {
1174-
uint8_t *scales_0 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 ;
1175-
uint8_t *scales_1 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 16;
1176-
uint8_t *scales_2 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 32;
1177-
uint8_t *scales_3 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 48;
1178-
for (int j = 0; j < ncols_interleaved; j++) {
1179-
sumi1 = 0;
1180-
sumi2 = 0;
1181-
sumi3 = 0;
1182-
sumi4 = 0;
1183-
sumi = 0;
1184-
int offset = ((k / 2) % 2) + j * 2;
1185-
for (int i = 0; i < blocklen; ++i){
1186-
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
1187-
const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
1188-
const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
1189-
const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
1190-
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
1191-
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
1192-
sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
1193-
sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);
1194-
1195-
sumi1 = sumi1 * (scales_0[offset] & 0xF);
1196-
sumi2 = sumi2 * (scales_1[offset] & 0xF);
1197-
sumi3 = sumi3 * (scales_2[offset] & 0xF);
1198-
sumi4 = sumi4 * (scales_3[offset] & 0xF);
1199-
sumi += sumi1 + sumi2 + sumi3 + sumi4;
1200-
}
1201-
sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
1202-
}
1203-
}
1204-
for(int sb = 0; sb < 8; sb++) {
1205-
uint8_t *mins = (uint8_t*) b_ptr[l].scales + sb * 16;
1206-
for(int j = 0; j < ncols_interleaved; j++){
1207-
sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
1208-
}
1209-
}
1210-
}
1211-
for (int j = 0; j < ncols_interleaved; j++) {
1212-
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
1213-
}
1214-
}
1160+
ggml_gemv_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
1161+
12151162
#endif
12161163
}
12171164

@@ -6294,74 +6241,8 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
62946241
}
62956242
#else
62966243

6297-
float sumf[4][8];
6298-
float sum_minf[4][8];
6299-
int sumi1, sumi2, sumi3, sumi4;
6300-
int sumi;
6301-
6302-
for (int y = 0; y < nr / 4; y++) {
6303-
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
6304-
for (int x = 0; x < nc / ncols_interleaved; x++) {
6305-
const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
6306-
for (int m = 0; m < 4; m++) {
6307-
for (int j = 0; j < ncols_interleaved; j++) {
6308-
sumf[m][j] = 0.0;
6309-
sum_minf[m][j] = 0.0;
6310-
}
6311-
}
6312-
for (int l = 0; l < nb; l++) {
6313-
for (int k = 0; k < (qk / (4 * blocklen)); k++) {
6314-
6315-
uint8_t *scales_0 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 ;
6316-
uint8_t *scales_1 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 16;
6317-
uint8_t *scales_2 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 32;
6318-
uint8_t *scales_3 = (uint8_t*) b_ptr[l].scales + (k / 4) * 64 + 48;
6319-
for (int m = 0; m < 4; m++) {
6320-
for (int j = 0; j < ncols_interleaved; j++) {
6321-
sumi1 = 0;
6322-
sumi2 = 0;
6323-
sumi3 = 0;
6324-
sumi4 = 0;
6325-
sumi = 0;
6326-
int offset = ((k / 2) % 2) + j * 2;
6327-
for (int i = 0; i < blocklen; ++i){
6328-
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x03);
6329-
const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 0x03);
6330-
const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 0x03);
6331-
const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 0x03);
6332-
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
6333-
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
6334-
sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
6335-
sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
6336-
sumi1 = sumi1 * (scales_0[offset] & 0xF);
6337-
sumi2 = sumi2 * (scales_1[offset] & 0xF);
6338-
sumi3 = sumi3 * (scales_2[offset] & 0xF);
6339-
sumi4 = sumi4 * (scales_3[offset] & 0xF);
6340-
sumi += sumi1 + sumi2 + sumi3 + sumi4;
6341-
}
6342-
sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
6343-
}
6344-
}
6345-
}
6346-
for(int sb = 0; sb < 8; sb++) {
6347-
uint8_t *mins = (uint8_t*) b_ptr[l].scales + sb * 16;
6348-
for(int m = 0; m < 4; m++) {
6349-
const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
6350-
for(int j = 0; j < ncols_interleaved; j++) {
6351-
int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]);
6352-
sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
6353-
}
6354-
}
6355-
}
6356-
}
6244+
ggml_gemm_q2_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
63576245

6358-
for (int m = 0; m < 4; m++) {
6359-
for (int j = 0; j < ncols_interleaved; j++) {
6360-
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
6361-
}
6362-
}
6363-
}
6364-
}
63656246

63666247
#endif
63676248
}

0 commit comments

Comments
 (0)