diff --git a/ggml-metal.m b/ggml-metal.m index 4267db9be..f9a2228aa 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -971,7 +971,7 @@ void ggml_metal_graph_compute( else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; + //[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } diff --git a/ggml-metal.metal b/ggml-metal.metal index 8cdf0b9d2..aeb33c581 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -515,11 +515,8 @@ kernel void kernel_mul_mat_f16_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpig[[thread_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -528,42 +525,16 @@ kernel void kernel_mul_mat_f16_f32( device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - uint ith = tpitg.x; - uint nth = tptg.x; - - sum[ith] = 0.0f; - - for (int i = ith; i < ne00; i += nth) { - sum[ith] += (float) x[i] * (float) y[i]; + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; } - // accumulate the sum from all threads in the threadgroup - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } - // Original implementation. Left behind commented out for now - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = tptg.x/2; i > 0; i /= 2) { - // if (tpitg.x < i) { - // sum[tpitg.x] += sum[tpitg.x + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - // - //if (tpitg.x == 0) { - // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; - //} } kernel void kernel_alibi_f32(