diff --git a/ggml-metal.m b/ggml-metal.m index 0e9b56aa3..814851203 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -57,6 +57,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_q5_k); GGML_METAL_DECL_KERNEL(get_rows_q6_k); GGML_METAL_DECL_KERNEL(rms_norm); + GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); @@ -66,8 +67,10 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); GGML_METAL_DECL_KERNEL(rope); + GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); + GGML_METAL_DECL_KERNEL(cpy_f16_f16); #undef GGML_METAL_DECL_KERNEL }; @@ -162,6 +165,7 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(get_rows_q5_k); GGML_METAL_ADD_KERNEL(get_rows_q6_k); GGML_METAL_ADD_KERNEL(rms_norm); + GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); @@ -171,8 +175,10 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); GGML_METAL_ADD_KERNEL(rope); + GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); + GGML_METAL_ADD_KERNEL(cpy_f16_f16); #undef GGML_METAL_ADD_KERNEL } @@ -735,6 +741,65 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case GGML_OP_NORM: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + const float eps = 1e-5f; + + const int nth = 256; + + [encoder setComputePipelineState:ctx->pipeline_norm]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ALIBI: + { + GGML_ASSERT((src0t == GGML_TYPE_F32)); + const int n_past = ((int32_t *) src1->data)[0]; + const int n_head = ((int32_t *) src1->data)[1]; + const float max_bias = ((float *) src1->data)[2]; + if (__builtin_popcount(n_head) != 1) { + GGML_ASSERT(false && "only power-of-two n_head implemented"); + } + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; + const int nth = 32; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_ROPE: { if (encoder == nil) { @@ -788,6 +853,14 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false && "not implemented"); }; } break; + case GGML_TYPE_F16: + { + switch (dstt) { + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break; + case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break; + default: GGML_ASSERT(false && "not implemented"); + }; + } break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 09e12a879..d1e49222d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1( (device float *) ((device char *) dst + i*nb1), ne00); } +kernel void kernel_norm( + device const void * src0, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant float & eps, + threadgroup float * sum [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); + // MEAN + // parallel sum + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + sum[tpitg] += x[i00]; + } + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + // broadcast + if (tpitg == 0) { + sum[0] /= ne00; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + const float mean = sum[0]; + + // recenter + device float * y = dst + tgpig*ne00; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = x[i00] - mean; + } + + // VARIANCE + // parallel sum + sum[tpitg] = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + sum[tpitg] += y[i00] * y[i00]; + } + // reduce + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = ntg/2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + // broadcast + if (tpitg == 0) { + sum[0] /= ne00; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + const float variance = sum[0]; + + const float scale = 1.0f/sqrt(variance + eps); + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + y[i00] = y[i00] * scale; + } +} + + kernel void kernel_rms_norm( device const void * src0, device float * dst, @@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32( } } +kernel void kernel_alibi_f32( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & m0, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + float m_k = pow(m0, i2 + 1); + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); + } +} + kernel void kernel_rope( device const void * src0, device float * dst, @@ -540,6 +648,47 @@ kernel void kernel_rope( } } +kernel void kernel_cpy_f16_f16( + device const half * src0, + device half * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; + + const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + const int64_t i3 = n / (ne2*ne1*ne0); + const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { + device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + kernel void kernel_cpy_f32_f16( device const float * src0, device half * dst,