diff --git a/ggml-metal.m b/ggml-metal.m index 88e7e1356..d0d23442e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -76,6 +76,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32); @@ -219,6 +220,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32); @@ -284,6 +286,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(norm); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); + GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row); GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32); @@ -868,7 +871,11 @@ void ggml_metal_graph_compute( { nth0 = 32; nth1 = 1; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + if (ne11 * ne12 < 4) { + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; + } else { + [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + } } break; case GGML_TYPE_Q4_0: { @@ -920,8 +927,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 2; - nth1 = 32; + nth0 = 4; //1; + nth1 = 8; //32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; } break; case GGML_TYPE_Q5_K: @@ -969,9 +976,12 @@ void ggml_metal_graph_compute( [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { + src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -985,8 +995,8 @@ void ggml_metal_graph_compute( else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { - [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + int64_t ny = (ne11 + 3)/4; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 8cdf0b9d2..3fa311b40 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -133,19 +133,24 @@ kernel void kernel_soft_max( threadgroup_barrier(mem_flags::mem_threadgroup); } - // broadcast - if (tpitg[0] == 0) { - buf[0] = buf[0]; - } + //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of + // the loop, and when that is done, buf[0] has the correct (synchronized) value + //if (tpitg[0] == 0) { + // buf[0] = buf[0]; + //} - threadgroup_barrier(mem_flags::mem_threadgroup); + //threadgroup_barrier(mem_flags::mem_threadgroup); const float max = buf[0]; // parallel sum buf[tpitg[0]] = 0.0f; for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - buf[tpitg[0]] += exp(psrc0[i00] - max); + const float exp_psrc0 = exp(psrc0[i00] - max); + buf[tpitg[0]] += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // whish to compute it twice. + pdst[i00] = exp_psrc0; } // reduce @@ -157,17 +162,18 @@ kernel void kernel_soft_max( threadgroup_barrier(mem_flags::mem_threadgroup); } - // broadcast - if (tpitg[0] == 0) { - buf[0] = buf[0]; - } + // broadcast - not needed, see above + //// broadcast + //if (tpitg[0] == 0) { + // buf[0] = buf[0]; + //} - threadgroup_barrier(mem_flags::mem_threadgroup); + //threadgroup_barrier(mem_flags::mem_threadgroup); const float sum = buf[0]; for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - pdst[i00] = exp(psrc0[i00] - max) / sum; + pdst[i00] /= sum; } } @@ -214,25 +220,27 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - // broadcast - if (tpitg == 0) { - sum[0] /= ne00; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + //// broadcast + //if (tpitg == 0) { + // sum[0] /= ne00; + //} + //threadgroup_barrier(mem_flags::mem_threadgroup); const float mean = sum[0]; - // recenter + // recenter and VARIANCE device float * y = dst + tgpig*ne00; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = x[i00] - mean; - } - - // VARIANCE - // parallel sum sum[tpitg] = 0.0f; for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = x[i00] - mean; sum[tpitg] += y[i00] * y[i00]; } + + //// VARIANCE + //// parallel sum + //sum[tpitg] = 0.0f; + //for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + // sum[tpitg] += y[i00] * y[i00]; + //} // reduce threadgroup_barrier(mem_flags::mem_threadgroup); for (uint i = ntg/2; i > 0; i /= 2) { @@ -241,11 +249,11 @@ kernel void kernel_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); } - // broadcast - if (tpitg == 0) { - sum[0] /= ne00; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + //// broadcast + //if (tpitg == 0) { + // sum[0] /= ne00; + //} + //threadgroup_barrier(mem_flags::mem_threadgroup); const float variance = sum[0]; const float scale = 1.0f/sqrt(variance + eps); @@ -435,6 +443,8 @@ kernel void kernel_mul_mat_q4_1_f32( mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } +#define NB_Q8_0 8 + kernel void kernel_mul_mat_q8_0_f32( device const void * src0, device const float * src1, @@ -463,30 +473,30 @@ kernel void kernel_mul_mat_q8_0_f32( device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; + float yl[NB_Q8_0]; float sumf[nr]={0.f}; - const int ix = tiisg/2; - const int il = tiisg%2; + const int ix = tiisg/4; + const int il = tiisg%4; - device const float * yb = y + ix * QK8_0 + 16*il; + device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { - for (int i = 0; i < 16; ++i) { + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (int ib = ix; ib < nb; ib += nw/4) { + for (int i = 0; i < NB_Q8_0; ++i) { yl[i] = yb[i]; } for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + 16*il; + device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; float sumq = 0.f; - for (int iq = 0; iq < 16; ++iq) { + for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } sumf[row] += sumq*x[ib+row*nb].d; } - yb += QK8_0 * 16; + yb += NB_Q8_0 * nw; } for (int row = 0; row < nr; ++row) { @@ -497,6 +507,60 @@ kernel void kernel_mul_mat_q8_0_f32( } } +kernel void kernel_mul_mat_f16_f32_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + const int64_t im = tgpig.z; + + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + if (ne00 < 128) { + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + device const half4 * x4 = (device const half4 *) x; + device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + +} + +#define N_F16_F32 4 + kernel void kernel_mul_mat_f16_f32( device const char * src0, device const char * src1, @@ -515,55 +579,58 @@ kernel void kernel_mul_mat_f16_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpig[[thread_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + const int64_t rb = N_F16_F32*tgpig.y; const int64_t im = tgpig.z; - device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); - uint ith = tpitg.x; - uint nth = tptg.x; + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } - sum[ith] = 0.0f; + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - for (int i = ith; i < ne00; i += nth) { - sum[ith] += (float) x[i] * (float) y[i]; + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const half4 * x4 = (device const half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) sumf += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } } - // accumulate the sum from all threads in the threadgroup - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; - } - - // Original implementation. Left behind commented out for now - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = tptg.x/2; i > 0; i /= 2) { - // if (tpitg.x < i) { - // sum[tpitg.x] += sum[tpitg.x + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - // - //if (tpitg.x == 0) { - // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; - //} } kernel void kernel_alibi_f32( @@ -1262,7 +1329,8 @@ kernel void kernel_mul_mat_q4_K_f32( const int r0 = tgpig.x; const int r1 = tgpig.y; const int r2 = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; const int ib_row = first_row * nb; const uint offset0 = r2/gqa*(nb*ne0); device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;