From 469c9addef75893e6be12edda852d12e840bf064 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 24 Oct 2023 09:46:50 +0300 Subject: [PATCH] metal : handle ggml_scale for n%4 != 0 (close #3754) ggml-ci --- ggml-metal.m | 18 +++++++++++++----- ggml-metal.metal | 10 +++++++++- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index c908106be..c1901dca7 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -62,6 +62,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul); GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast GGML_METAL_DECL_KERNEL(scale); + GGML_METAL_DECL_KERNEL(scale_4); GGML_METAL_DECL_KERNEL(silu); GGML_METAL_DECL_KERNEL(relu); GGML_METAL_DECL_KERNEL(gelu); @@ -249,6 +250,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul); GGML_METAL_ADD_KERNEL(mul_row); GGML_METAL_ADD_KERNEL(scale); + GGML_METAL_ADD_KERNEL(scale_4); GGML_METAL_ADD_KERNEL(silu); GGML_METAL_ADD_KERNEL(relu); GGML_METAL_ADD_KERNEL(gelu); @@ -347,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul); GGML_METAL_DEL_KERNEL(mul_row); GGML_METAL_DEL_KERNEL(scale); + GGML_METAL_DEL_KERNEL(scale_4); GGML_METAL_DEL_KERNEL(silu); GGML_METAL_DEL_KERNEL(relu); GGML_METAL_DEL_KERNEL(gelu); @@ -923,15 +926,20 @@ void ggml_metal_graph_compute( const float scale = *(const float *) src1->data; - [encoder setComputePipelineState:ctx->pipeline_scale]; + int64_t n = ggml_nelements(dst); + + if (n % 4 == 0) { + n /= 4; + [encoder setComputePipelineState:ctx->pipeline_scale_4]; + } else { + [encoder setComputePipelineState:ctx->pipeline_scale]; + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { diff --git a/ggml-metal.metal b/ggml-metal.metal index 69fc71362..f4b460564 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -125,9 +125,17 @@ kernel void kernel_mul_row( } kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( device const float4 * src0, device float4 * dst, - constant float & scale, + constant float & scale, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * scale; }