From 289313716ff7ccf6aee284f686a0fe8cbc7714af Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 3 Jan 2024 11:35:46 +0200 Subject: [PATCH] metal : add kernel_get_rows_i32 ggml-ci --- ggml-metal.m | 4 ++++ ggml-metal.metal | 29 +++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/ggml-metal.m b/ggml-metal.m index 7a369b55e..7aa92c14c 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -87,6 +87,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_q4_K); GGML_METAL_DECL_KERNEL(get_rows_q5_K); GGML_METAL_DECL_KERNEL(get_rows_q6_K); + GGML_METAL_DECL_KERNEL(get_rows_i32); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(group_norm); GGML_METAL_DECL_KERNEL(norm); @@ -377,6 +378,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(get_rows_q4_K); GGML_METAL_ADD_KERNEL(get_rows_q5_K); GGML_METAL_ADD_KERNEL(get_rows_q6_K); + GGML_METAL_ADD_KERNEL(get_rows_i32); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(group_norm); GGML_METAL_ADD_KERNEL(norm); @@ -499,6 +501,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_q4_K); GGML_METAL_DEL_KERNEL(get_rows_q5_K); GGML_METAL_DEL_KERNEL(get_rows_q6_K); + GGML_METAL_DEL_KERNEL(get_rows_i32); GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(group_norm); GGML_METAL_DEL_KERNEL(norm); @@ -1978,6 +1981,7 @@ void ggml_metal_graph_compute( case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; + case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 9aa7b502a..a7d3f9efa 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3829,6 +3829,35 @@ kernel void kernel_get_rows_f16( } } +kernel void kernel_get_rows_i32( + device const void * src0, + device const char * src1, + device int32_t * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + + #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_K 32