~4-5% improvement for Q8_0 TG on metal

This commit is contained in:
Iwan Kawrakow 2023-09-03 09:00:27 +03:00
parent 363f0bf558
commit 6af0bab347

View file

@ -443,6 +443,8 @@ kernel void kernel_mul_mat_q4_1_f32(
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
} }
#define NB_Q8_0 8
kernel void kernel_mul_mat_q8_0_f32( kernel void kernel_mul_mat_q8_0_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
@ -471,30 +473,30 @@ kernel void kernel_mul_mat_q8_0_f32(
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[16]; float yl[NB_Q8_0];
float sumf[nr]={0.f}; float sumf[nr]={0.f};
const int ix = tiisg/2; const int ix = tiisg/4;
const int il = tiisg%2; const int il = tiisg%4;
device const float * yb = y + ix * QK8_0 + 16*il; device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
// each thread in a SIMD group deals with half a block. // each thread in a SIMD group deals with NB_Q8_0 quants at a time
for (int ib = ix; ib < nb; ib += nw/2) { for (int ib = ix; ib < nb; ib += nw/4) {
for (int i = 0; i < 16; ++i) { for (int i = 0; i < NB_Q8_0; ++i) {
yl[i] = yb[i]; yl[i] = yb[i];
} }
for (int row = 0; row < nr; row++) { for (int row = 0; row < nr; row++) {
device const int8_t * qs = x[ib+row*nb].qs + 16*il; device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
float sumq = 0.f; float sumq = 0.f;
for (int iq = 0; iq < 16; ++iq) { for (int iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq]; sumq += qs[iq] * yl[iq];
} }
sumf[row] += sumq*x[ib+row*nb].d; sumf[row] += sumq*x[ib+row*nb].d;
} }
yb += QK8_0 * 16; yb += NB_Q8_0 * nw;
} }
for (int row = 0; row < nr; ++row) { for (int row = 0; row < nr; ++row) {