Another speed gain for Q4_0 and Q4_1 on Metal (#2375)

* Another speed gain for Q4_0 and Q4_1 on Metal

* Have N_DST, etc., be template parameters

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2023-07-25 13:48:29 +03:00 committed by GitHub
parent 129d844c87
commit 9a08eaf3c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -387,87 +387,90 @@ kernel void kernel_rms_norm(
} }
} }
// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { // il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d; float d = qb_curr->d;
float4 acc = 0.f; float2 acc = 0.f;
device uint16_t * qs = ((device uint16_t *)qb_curr + 1); device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
for (int i = 0; i < 16; i+=2) { for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i] * (qs[i / 2] & 0x000F); acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + yl[i + 9] * (qs[i / 2] & 0xF000);
} }
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); return d * (sumy * -8.f + acc[0] + acc[1]);
} }
// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { // il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d; float d = qb_curr->d;
float m = qb_curr->m; float m = qb_curr->m;
float4 acc = 0.f; device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
device uint16_t * qs = ((device uint16_t *)qb_curr + 2); float2 acc = 0.f;
for (int i = 0; i < 16; i+=2) { for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i] * (qs[i / 2] & 0x000F); acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + yl[i + 9] * (qs[i / 2] & 0xF000);
} }
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; return d * (acc[0] + acc[1]) + sumy * m;
} }
// putting them in the kernel cause a significant performance penalty // putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows #define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
template<typename block_q_type> //Note: This is a template, but strictly speaking it only applies to
// quantizations where the block size is 32. It also does not
// giard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
uint2 tgpig, uint tiisg, uint sgitg) { uint2 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0; const int nb = ne00/QK4_0;
const int r0 = tgpig.x; const int r0 = tgpig.x;
const int r1 = tgpig.y; const int r1 = tgpig.y;
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; const int first_row = (r0 * nsg + sgitg) * nr;
device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
device const float * y = (device const float *) src1 + r1*ne10; device const float * y = (device const float *) src1 + r1*ne10;
float4 y_curr[8]; // src1 vector cache float yl[16]; // src1 vector cache
float sumf[N_DST]={0.f}, all_sum; float sumf[nr]={0.f};
thread float * yl=(thread float *)y_curr;
// each thread in a SIMD group deals with 1 block. const int ix = tiisg/2;
for (int column = 0; column < nb / N_SIMDWIDTH; column++) { const int il = 8*(tiisg%2);
device const float * yb = y + ix * QK4_0 + il;
// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += nw/2) {
float sumy = 0; float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) { for (int i = 0; i < 8; i += 2) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); sumy += yb[i] + yb[i+1];
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; yl[i+0] = yb[i+ 0];
yl[i+1] = yb[i+ 1]/256.f;
sumy += yb[i+16] + yb[i+17];
yl[i+8] = yb[i+16]/16.f;
yl[i+9] = yb[i+17]/4096.f;
} }
for (int row = 0; row < N_DST; row++) { for (int row = 0; row < nr; row++) {
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
}
} }
// from now loads two rows every time and 16 blocks per row yb += QK4_0 * 16;
int ir = tiisg / (N_SIMDWIDTH / 2);
int ib = tiisg % (N_SIMDWIDTH / 2);
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
} }
for (int row = 0; row < N_DST; row+=2) { for (int row = 0; row < nr; ++row) {
if (nb_start + ib < nb) { const float tot = simd_sum(sumf[row]);
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); if (tiisg == 0 && first_row + row < ne01) {
} dst[r1*ne0 + first_row + row] = tot;
}
}
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
} }
} }
} }
@ -483,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32(
uint2 tgpig[[threadgroup_position_in_grid]], uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
} }
kernel void kernel_mul_mat_q4_1_f32( kernel void kernel_mul_mat_q4_1_f32(
@ -497,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32(
uint2 tgpig[[threadgroup_position_in_grid]], uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
} }
kernel void kernel_mul_mat_f16_f32( kernel void kernel_mul_mat_f16_f32(