From e8d91589258f9204397a7ac5f9b3c857835c98f8 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Fri, 1 Sep 2023 11:15:57 +0300 Subject: [PATCH] metal: somewhat faster f16 x f32 matrix multiply kernel (#2951) * Somewhat faster f16 x f32 matrix multiply kernel * Better use 32 thread groups for f16 x f32 --------- Co-authored-by: Iwan Kawrakow --- ggml-metal.m | 2 +- ggml-metal.metal | 38 ++++++++++++++++++++++++++++---------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e929c4b07..8c3c64f53 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -840,7 +840,7 @@ void ggml_metal_graph_compute( switch (src0t) { case GGML_TYPE_F16: { - nth0 = 64; + nth0 = 32; nth1 = 1; [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 82e1a0c7a..02db5323e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32( device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - sum[tpitg.x] = 0.0f; + uint ith = tpitg.x; + uint nth = tptg.x; - for (int i = tpitg.x; i < ne00; i += tptg.x) { - sum[tpitg.x] += (float) x[i] * (float) y[i]; + sum[ith] = 0.0f; + + for (int i = ith; i < ne00; i += nth) { + sum[ith] += (float) x[i] * (float) y[i]; } // accumulate the sum from all threads in the threadgroup threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = tptg.x/2; i > 0; i /= 2) { - if (tpitg.x < i) { - sum[tpitg.x] += sum[tpitg.x + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; } - - if (tpitg.x == 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; } + + // Original implementation. Left behind commented out for now + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (uint i = tptg.x/2; i > 0; i /= 2) { + // if (tpitg.x < i) { + // sum[tpitg.x] += sum[tpitg.x + i]; + // } + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + // + //if (tpitg.x == 0) { + // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; + //} } kernel void kernel_alibi_f32(