From 363f0bf5580d9d59aaabfb227f6595f4806b32ab Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 2 Sep 2023 18:14:41 +0300 Subject: [PATCH] Massive improvement for TG for fp16 --- ggml-metal.metal | 73 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 62b222aa4..e2eb5ba35 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -534,14 +534,27 @@ kernel void kernel_mul_mat_f16_f32_1row( device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; + if (ne00 < 128) { + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + device const half4 * x4 = (device const half4 *) x; + device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } } #define N_F16_F32 4 @@ -573,22 +586,46 @@ kernel void kernel_mul_mat_f16_f32( device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } } }