From 6af0bab347574fcb18081bda12a5d9de04dfe367 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 3 Sep 2023 09:00:27 +0300 Subject: [PATCH] ~4-5% improvement for Q8_0 TG on metal --- ggml-metal.metal | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index e2eb5ba35..3fa311b40 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -443,6 +443,8 @@ kernel void kernel_mul_mat_q4_1_f32( mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } +#define NB_Q8_0 8 + kernel void kernel_mul_mat_q8_0_f32( device const void * src0, device const float * src1, @@ -471,30 +473,30 @@ kernel void kernel_mul_mat_q8_0_f32( device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; + float yl[NB_Q8_0]; float sumf[nr]={0.f}; - const int ix = tiisg/2; - const int il = tiisg%2; + const int ix = tiisg/4; + const int il = tiisg%4; - device const float * yb = y + ix * QK8_0 + 16*il; + device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { - for (int i = 0; i < 16; ++i) { + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (int ib = ix; ib < nb; ib += nw/4) { + for (int i = 0; i < NB_Q8_0; ++i) { yl[i] = yb[i]; } for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + 16*il; + device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; float sumq = 0.f; - for (int iq = 0; iq < 16; ++iq) { + for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } sumf[row] += sumq*x[ib+row*nb].d; } - yb += QK8_0 * 16; + yb += NB_Q8_0 * nw; } for (int row = 0; row < nr; ++row) {