From 2b601702a80e7b30fadc6c886632fb8638111ba9 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 1 Sep 2023 17:06:53 +0300 Subject: [PATCH] Quite significant PP speedup on metal --- ggml-metal.m | 11 +++++++---- ggml-metal.metal | 33 ++++++++++++++++++++++----------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index f9a2228aa..d365ddc67 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -906,8 +906,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 2; - nth1 = 32; + nth0 = 4; //1; + nth1 = 8; //32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; } break; case GGML_TYPE_Q5_K: @@ -955,9 +955,12 @@ void ggml_metal_graph_compute( [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { + src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -972,7 +975,7 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { //[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, (ne11 + 3)/4, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 171b0bcf9..caf8d37a4 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -505,6 +505,8 @@ kernel void kernel_mul_mat_q8_0_f32( } } +#define N_F16_F32 4 + kernel void kernel_mul_mat_f16_f32( device const char * src0, device const char * src1, @@ -527,20 +529,28 @@ kernel void kernel_mul_mat_f16_f32( uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + const int64_t rb = N_F16_F32*tgpig.y; const int64_t im = tgpig.z; - device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } } } @@ -1241,7 +1251,8 @@ kernel void kernel_mul_mat_q4_K_f32( const int r0 = tgpig.x; const int r1 = tgpig.y; const int r2 = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; const int ib_row = first_row * nb; const uint offset0 = r2/gqa*(nb*ne0); device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;