From 396ebd1e80c7953e271371a771dc5249c4811813 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 13 Jan 2024 18:03:45 +0200
Subject: [PATCH] metal : refactor kernel loading code (llama/4794)

* metal : detect more GPU families

* metal : refactor kernel loading

* metal : set kernel family requirements

* metal : fix kernel init + fix compile options

* metal : take into account simdgroup reduction support

* metal : print only skipped kernels

* metal : fix check for simdgroup reduction support

* metal : check for Metal 3

* metal : free allocations

* metal : normalize encoder:setComputePipelineStatus calls

ggml-ci

* metal : fix Metal3 family check

ggml-ci

* metal : check for simdgroup matrix mul. feature

ggml-ci
---
 ggml-metal.m | 1048 +++++++++++++++++++++++++-------------------------
 1 file changed, 530 insertions(+), 518 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index c036240..6c28a7e 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -26,6 +26,8 @@
 
 #define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
 
+#define GGML_METAL_MAX_KERNELS 256
+
 struct ggml_metal_buffer {
     const char * name;
 
@@ -35,6 +37,134 @@ struct ggml_metal_buffer {
     id metal;
 };
 
+struct ggml_metal_kernel {
+    id function;
+    id pipeline;
+};
+
+enum ggml_metal_kernel_type {
+    GGML_METAL_KERNEL_TYPE_ADD,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_MUL,
+    GGML_METAL_KERNEL_TYPE_MUL_ROW,
+    GGML_METAL_KERNEL_TYPE_DIV,
+    GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_SCALE,
+    GGML_METAL_KERNEL_TYPE_SCALE_4,
+    GGML_METAL_KERNEL_TYPE_TANH,
+    GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_GELU,
+    GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+    GGML_METAL_KERNEL_TYPE_SILU,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
+    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_GROUP_NORM,
+    GGML_METAL_KERNEL_TYPE_NORM,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
+
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_ROPE_F32, + GGML_METAL_KERNEL_TYPE_ROPE_F16, + GGML_METAL_KERNEL_TYPE_ALIBI_F32, + GGML_METAL_KERNEL_TYPE_IM2COL_F16, + GGML_METAL_KERNEL_TYPE_UPSCALE_F32, + GGML_METAL_KERNEL_TYPE_PAD_F32, + GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, + GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, + GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_CPY_F32_F16, + GGML_METAL_KERNEL_TYPE_CPY_F32_F32, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, + //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, + //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, + GGML_METAL_KERNEL_TYPE_CPY_F16_F16, + GGML_METAL_KERNEL_TYPE_CPY_F16_F32, + GGML_METAL_KERNEL_TYPE_CONCAT, + GGML_METAL_KERNEL_TYPE_SQR, + GGML_METAL_KERNEL_TYPE_SUM_ROWS, + + GGML_METAL_KERNEL_TYPE_COUNT +}; + struct ggml_metal_context { int n_cb; @@ -50,134 +180,13 @@ struct ggml_metal_context { int n_buffers; struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; + struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS]; + int concur_list[GGML_MAX_CONCUR]; int concur_list_len; - // custom kernels -#define GGML_METAL_DECL_KERNEL(name) \ - id function_##name; \ - id pipeline_##name - - GGML_METAL_DECL_KERNEL(add); - GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast - GGML_METAL_DECL_KERNEL(mul); - GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast - GGML_METAL_DECL_KERNEL(div); - GGML_METAL_DECL_KERNEL(div_row); - GGML_METAL_DECL_KERNEL(scale); - GGML_METAL_DECL_KERNEL(scale_4); - GGML_METAL_DECL_KERNEL(tanh); - GGML_METAL_DECL_KERNEL(relu); - GGML_METAL_DECL_KERNEL(gelu); - GGML_METAL_DECL_KERNEL(gelu_quick); - GGML_METAL_DECL_KERNEL(silu); - 
GGML_METAL_DECL_KERNEL(soft_max); - GGML_METAL_DECL_KERNEL(soft_max_4); - GGML_METAL_DECL_KERNEL(diag_mask_inf); - GGML_METAL_DECL_KERNEL(diag_mask_inf_8); - GGML_METAL_DECL_KERNEL(get_rows_f32); - GGML_METAL_DECL_KERNEL(get_rows_f16); - GGML_METAL_DECL_KERNEL(get_rows_q4_0); - GGML_METAL_DECL_KERNEL(get_rows_q4_1); - GGML_METAL_DECL_KERNEL(get_rows_q5_0); - GGML_METAL_DECL_KERNEL(get_rows_q5_1); - GGML_METAL_DECL_KERNEL(get_rows_q8_0); - GGML_METAL_DECL_KERNEL(get_rows_q2_K); - GGML_METAL_DECL_KERNEL(get_rows_q3_K); - GGML_METAL_DECL_KERNEL(get_rows_q4_K); - GGML_METAL_DECL_KERNEL(get_rows_q5_K); - GGML_METAL_DECL_KERNEL(get_rows_q6_K); - GGML_METAL_DECL_KERNEL(get_rows_i32); - GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs); - GGML_METAL_DECL_KERNEL(get_rows_iq2_xs); - GGML_METAL_DECL_KERNEL(rms_norm); - GGML_METAL_DECL_KERNEL(group_norm); - GGML_METAL_DECL_KERNEL(norm); - GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); - GGML_METAL_DECL_KERNEL(mul_mv_f16_f16); - GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); - GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row); - GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); - GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32); - GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32); - //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16); - GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32); - //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row); - //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4); - GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32); - GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32); - GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); - GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32); - GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32); - 
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32); - GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32); - GGML_METAL_DECL_KERNEL(rope_f32); - GGML_METAL_DECL_KERNEL(rope_f16); - GGML_METAL_DECL_KERNEL(alibi_f32); - GGML_METAL_DECL_KERNEL(im2col_f16); - GGML_METAL_DECL_KERNEL(upscale_f32); - GGML_METAL_DECL_KERNEL(pad_f32); - GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc); - GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc); - GGML_METAL_DECL_KERNEL(leaky_relu_f32); - GGML_METAL_DECL_KERNEL(cpy_f32_f16); - GGML_METAL_DECL_KERNEL(cpy_f32_f32); - GGML_METAL_DECL_KERNEL(cpy_f32_q8_0); - GGML_METAL_DECL_KERNEL(cpy_f32_q4_0); - GGML_METAL_DECL_KERNEL(cpy_f32_q4_1); - //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0); - //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1); - GGML_METAL_DECL_KERNEL(cpy_f16_f16); - GGML_METAL_DECL_KERNEL(cpy_f16_f32); - GGML_METAL_DECL_KERNEL(concat); - GGML_METAL_DECL_KERNEL(sqr); - GGML_METAL_DECL_KERNEL(sum_rows); - -#undef GGML_METAL_DECL_KERNEL + bool support_simdgroup_reduction; + bool support_simdgroup_mm; }; // MSL code @@ -298,19 +307,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { return NULL; } - MTLCompileOptions* options = nil; + // dictionary of preprocessor macros + NSMutableDictionary * prep = [NSMutableDictionary dictionary]; + #ifdef GGML_QKK_64 - options = [MTLCompileOptions new]; - options.preprocessorMacros = @{ @"QK_K" : @(64) }; + prep[@"QK_K"] = @(64); #endif - // try to disable fast-math - // NOTE: this seems to have no effect whatsoever - // instead, in order to disable fast-math, we have to build default.metallib from the command line - // using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air - // and go through the "pre-compiled library found" path above + + MTLCompileOptions* options = [MTLCompileOptions new]; + options.preprocessorMacros = prep; + //[options setFastMathEnabled:false]; ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + + [options release]; + [prep release]; } if (error) { @@ -323,16 +335,41 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { // print MTL GPU family: GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]); + const NSInteger MTLGPUFamilyMetal3 = 5001; + // determine max supported GPU family // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf - for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { - if ([ctx->device supportsFamily:i]) { - GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); - break; + { + for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { + if ([ctx->device supportsFamily:i]) { + GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { + if ([ctx->device supportsFamily:i]) { + GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) { + if ([ctx->device supportsFamily:i]) { + GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i); + break; + } } } + ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7]; + 
ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3]; + + ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7]; + + GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false"); + GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false"); GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6); if (ctx->device.maxTransferRate != 0) { @@ -346,141 +383,152 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { { NSError * error = nil; + for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) { + ctx->kernels[i].function = nil; + ctx->kernels[i].pipeline = nil; + } + /* - GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \ - (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \ - (int) ctx->pipeline_##name.threadExecutionWidth); \ + GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ */ -#define GGML_METAL_ADD_KERNEL(name) \ - ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ - ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ - if (error) { \ - GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ - return NULL; \ +#define GGML_METAL_ADD_KERNEL(e, name, supported) \ + if (supported) { \ + struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ + kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ + kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ + GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ + if (error) { \ + GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + return NULL; \ + } \ + } else { \ + GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \ } - GGML_METAL_ADD_KERNEL(add); - GGML_METAL_ADD_KERNEL(add_row); - GGML_METAL_ADD_KERNEL(mul); - GGML_METAL_ADD_KERNEL(mul_row); - GGML_METAL_ADD_KERNEL(div); - GGML_METAL_ADD_KERNEL(div_row); - GGML_METAL_ADD_KERNEL(scale); - GGML_METAL_ADD_KERNEL(scale_4); - GGML_METAL_ADD_KERNEL(tanh); - GGML_METAL_ADD_KERNEL(relu); - GGML_METAL_ADD_KERNEL(gelu); - GGML_METAL_ADD_KERNEL(gelu_quick); - GGML_METAL_ADD_KERNEL(silu); - GGML_METAL_ADD_KERNEL(soft_max); - GGML_METAL_ADD_KERNEL(soft_max_4); - GGML_METAL_ADD_KERNEL(diag_mask_inf); - GGML_METAL_ADD_KERNEL(diag_mask_inf_8); - GGML_METAL_ADD_KERNEL(get_rows_f32); - GGML_METAL_ADD_KERNEL(get_rows_f16); - GGML_METAL_ADD_KERNEL(get_rows_q4_0); - GGML_METAL_ADD_KERNEL(get_rows_q4_1); - GGML_METAL_ADD_KERNEL(get_rows_q5_0); - GGML_METAL_ADD_KERNEL(get_rows_q5_1); - GGML_METAL_ADD_KERNEL(get_rows_q8_0); - GGML_METAL_ADD_KERNEL(get_rows_q2_K); - GGML_METAL_ADD_KERNEL(get_rows_q3_K); - 
GGML_METAL_ADD_KERNEL(get_rows_q4_K); - GGML_METAL_ADD_KERNEL(get_rows_q5_K); - GGML_METAL_ADD_KERNEL(get_rows_q6_K); - GGML_METAL_ADD_KERNEL(get_rows_i32); - GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs); - GGML_METAL_ADD_KERNEL(get_rows_iq2_xs); - GGML_METAL_ADD_KERNEL(rms_norm); - GGML_METAL_ADD_KERNEL(group_norm); - GGML_METAL_ADD_KERNEL(norm); - GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); - GGML_METAL_ADD_KERNEL(mul_mv_f16_f16); - GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row); - GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); - GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32); - GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32); - //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16); - GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32); - //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row); - //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4); - GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32); - GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32); - if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { - GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); - GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32); - GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32); - GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32); - } - GGML_METAL_ADD_KERNEL(rope_f32); - GGML_METAL_ADD_KERNEL(rope_f16); - GGML_METAL_ADD_KERNEL(alibi_f32); - GGML_METAL_ADD_KERNEL(im2col_f16); - GGML_METAL_ADD_KERNEL(upscale_f32); - GGML_METAL_ADD_KERNEL(pad_f32); - GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc); - GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc); - GGML_METAL_ADD_KERNEL(leaky_relu_f32); - GGML_METAL_ADD_KERNEL(cpy_f32_f16); - 
GGML_METAL_ADD_KERNEL(cpy_f32_f32); - GGML_METAL_ADD_KERNEL(cpy_f32_q8_0); - GGML_METAL_ADD_KERNEL(cpy_f32_q4_0); - GGML_METAL_ADD_KERNEL(cpy_f32_q4_1); - //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0); - //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1); - GGML_METAL_ADD_KERNEL(cpy_f16_f16); - GGML_METAL_ADD_KERNEL(cpy_f16_f32); - GGML_METAL_ADD_KERNEL(concat); - GGML_METAL_ADD_KERNEL(sqr); - GGML_METAL_ADD_KERNEL(sum_rows); + // simd_sum and simd_max requires MTLGPUFamilyApple7 -#undef GGML_METAL_ADD_KERNEL + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, 
mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); + 
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } return ctx; @@ -488,137 +536,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); -#define GGML_METAL_DEL_KERNEL(name) \ - [ctx->function_##name release]; \ - [ctx->pipeline_##name release]; - - GGML_METAL_DEL_KERNEL(add); - GGML_METAL_DEL_KERNEL(add_row); - GGML_METAL_DEL_KERNEL(mul); - GGML_METAL_DEL_KERNEL(mul_row); - GGML_METAL_DEL_KERNEL(div); - GGML_METAL_DEL_KERNEL(div_row); - GGML_METAL_DEL_KERNEL(scale); - GGML_METAL_DEL_KERNEL(scale_4); - GGML_METAL_DEL_KERNEL(tanh); - GGML_METAL_DEL_KERNEL(relu); - GGML_METAL_DEL_KERNEL(gelu); - GGML_METAL_DEL_KERNEL(gelu_quick); - GGML_METAL_DEL_KERNEL(silu); - GGML_METAL_DEL_KERNEL(soft_max); - GGML_METAL_DEL_KERNEL(soft_max_4); - GGML_METAL_DEL_KERNEL(diag_mask_inf); - GGML_METAL_DEL_KERNEL(diag_mask_inf_8); - GGML_METAL_DEL_KERNEL(get_rows_f32); - GGML_METAL_DEL_KERNEL(get_rows_f16); - GGML_METAL_DEL_KERNEL(get_rows_q4_0); - GGML_METAL_DEL_KERNEL(get_rows_q4_1); - GGML_METAL_DEL_KERNEL(get_rows_q5_0); - GGML_METAL_DEL_KERNEL(get_rows_q5_1); - GGML_METAL_DEL_KERNEL(get_rows_q8_0); - GGML_METAL_DEL_KERNEL(get_rows_q2_K); - GGML_METAL_DEL_KERNEL(get_rows_q3_K); - GGML_METAL_DEL_KERNEL(get_rows_q4_K); - GGML_METAL_DEL_KERNEL(get_rows_q5_K); - GGML_METAL_DEL_KERNEL(get_rows_q6_K); - GGML_METAL_DEL_KERNEL(get_rows_i32); - GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs); - GGML_METAL_DEL_KERNEL(get_rows_iq2_xs); - GGML_METAL_DEL_KERNEL(rms_norm); - GGML_METAL_DEL_KERNEL(group_norm); - GGML_METAL_DEL_KERNEL(norm); - GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); - GGML_METAL_DEL_KERNEL(mul_mv_f16_f16); - GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); - GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row); - GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); - GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32); - 
GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32); - GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32); - //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16); - GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32); - //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row); - //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4); - GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32); - GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32); - if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { - GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); - GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32); - GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32); - GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32); - } - GGML_METAL_DEL_KERNEL(rope_f32); - GGML_METAL_DEL_KERNEL(rope_f16); - GGML_METAL_DEL_KERNEL(alibi_f32); - GGML_METAL_DEL_KERNEL(im2col_f16); - GGML_METAL_DEL_KERNEL(upscale_f32); - GGML_METAL_DEL_KERNEL(pad_f32); - GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc); - GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc); - GGML_METAL_DEL_KERNEL(leaky_relu_f32); - GGML_METAL_DEL_KERNEL(cpy_f32_f16); - GGML_METAL_DEL_KERNEL(cpy_f32_f32); - GGML_METAL_DEL_KERNEL(cpy_f32_q8_0); - GGML_METAL_DEL_KERNEL(cpy_f32_q4_0); - GGML_METAL_DEL_KERNEL(cpy_f32_q4_1); - //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0); - //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1); - GGML_METAL_DEL_KERNEL(cpy_f16_f16); - GGML_METAL_DEL_KERNEL(cpy_f16_f32); - GGML_METAL_DEL_KERNEL(concat); - GGML_METAL_DEL_KERNEL(sqr); - GGML_METAL_DEL_KERNEL(sum_rows); - -#undef GGML_METAL_DEL_KERNEL for (int i = 0; i < ctx->n_buffers; ++i) { [ctx->buffers[i].metal release]; } + for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) { + if (ctx->kernels[i].pipeline) { + [ctx->kernels[i].pipeline release]; + } + + if (ctx->kernels[i].function) { + [ctx->kernels[i].function 
release]; + } + } + [ctx->library release]; [ctx->queue release]; [ctx->device release]; @@ -930,7 +862,7 @@ void ggml_metal_graph_find_concurrency( } } -static bool ggml_metal_supports_op(const struct ggml_tensor * op) { +static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) { switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -956,9 +888,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SUM_ROWS: + return true; case GGML_OP_SOFT_MAX: case GGML_OP_RMS_NORM: case GGML_OP_GROUP_NORM: + return ctx->support_simdgroup_reduction; case GGML_OP_NORM: case GGML_OP_ALIBI: case GGML_OP_ROPE: @@ -967,9 +901,10 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { case GGML_OP_PAD: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: - return true; + return ctx->support_simdgroup_reduction; case GGML_OP_CPY: case GGML_OP_DUP: case GGML_OP_CONT: @@ -1007,6 +942,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { return false; } } + bool ggml_metal_graph_compute( struct ggml_metal_context * ctx, struct ggml_cgraph * gf) { @@ -1077,7 +1013,7 @@ bool ggml_metal_graph_compute( } break; } - if (!ggml_metal_supports_op(dst)) { + if (!ggml_metal_supports_op(ctx, dst)) { GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); GGML_ASSERT(!"unsupported op"); } @@ -1143,7 +1079,9 @@ bool ggml_metal_graph_compute( { const int64_t nb = ne00; - [encoder setComputePipelineState:ctx->pipeline_concat]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1197,18 +1135,18 @@ bool ggml_metal_graph_compute( nb = ne00 / 4; switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break; - case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break; - case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break; + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; default: GGML_ASSERT(false); } bcast_row = true; } else { switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->pipeline_add; break; - case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break; - case GGML_OP_DIV: pipeline = ctx->pipeline_div; break; + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; default: GGML_ASSERT(false); } } @@ -1275,9 +1213,9 @@ bool ggml_metal_graph_compute( // not sure how to avoid this // TODO: make a simpler cpy_bytes kernel - const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00); + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; - [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst 
offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; @@ -1297,10 +1235,14 @@ bool ggml_metal_graph_compute( [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } - [encoder setComputePipelineState:ctx->pipeline_add]; + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1330,7 +1272,7 @@ bool ggml_metal_graph_compute( [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00); + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -1342,13 +1284,16 @@ bool ggml_metal_graph_compute( int64_t n = ggml_nelements(dst); + id pipeline = nil; + if (n % 4 == 0) { n /= 4; - [encoder setComputePipelineState:ctx->pipeline_scale_4]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; } else { - [encoder setComputePipelineState:ctx->pipeline_scale]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; } + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; @@ -1359,7 +1304,9 @@ bool ggml_metal_graph_compute( switch (ggml_get_unary_op(gf->nodes[i])) { case GGML_UNARY_OP_TANH: { - [encoder setComputePipelineState:ctx->pipeline_tanh]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1369,7 +1316,9 @@ bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_RELU: { - [encoder setComputePipelineState:ctx->pipeline_relu]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1379,7 +1328,9 @@ bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU: { - [encoder setComputePipelineState:ctx->pipeline_gelu]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1390,7 +1341,9 @@ bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_GELU_QUICK: { - [encoder setComputePipelineState:ctx->pipeline_gelu_quick]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1401,7 +1354,9 @@ bool ggml_metal_graph_compute( } break; case GGML_UNARY_OP_SILU: { - [encoder setComputePipelineState:ctx->pipeline_silu]; + id pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1420,18 +1375,23 @@ bool ggml_metal_graph_compute( { GGML_ASSERT(ggml_is_contiguous(src0)); - [encoder setComputePipelineState:ctx->pipeline_sqr]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; const int64_t n = ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_SUM_ROWS: { GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - [encoder setComputePipelineState:ctx->pipeline_sum_rows]; + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; @@ -1465,20 +1425,23 @@ bool ggml_metal_graph_compute( { int nth = 32; // SIMD width + id pipeline = nil; + if (ne00%4 == 0) { while (nth < ne00/4 && nth < 256) { nth *= 2; } - [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline; } else { while (nth < ne00 && nth < 1024) { nth *= 2; } - [encoder setComputePipelineState:ctx->pipeline_soft_max]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; } const float scale = ((float *) dst->op_params)[0]; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1498,11 +1461,15 @@ bool ggml_metal_graph_compute( { const int n_past = ((int32_t *)(dst->op_params))[0]; + id pipeline = nil; + if (ne00%8 == 0) { - [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; } else { - [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; } + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; @@ -1562,23 +1529,28 @@ bool ggml_metal_graph_compute( ne00 % 32 == 0 && ne00 >= 64 && (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + id pipeline = nil; + switch (src0->type) { - case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; - case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; - case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; - case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; - case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break; - case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break; - case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; - case GGML_TYPE_Q2_K: [encoder 
setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; - case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break; - case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } + + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1602,12 +1574,14 @@ bool ggml_metal_graph_compute( int nrows = 1; //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + id pipeline = nil; + // use custom matrix x vector kernel switch (src0t) { case GGML_TYPE_F32: { GGML_ASSERT(src1t == GGML_TYPE_F32); - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; nrows = 4; } break; case GGML_TYPE_F16: @@ -1616,16 +1590,16 @@ bool ggml_metal_graph_compute( nth1 = 1; if (src1t == GGML_TYPE_F32) { if (ne11 * ne12 < 4) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; nrows = ne11; } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; nrows = 4; } } else { - 
                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
                        nrows = 4;
                    }
                } break;
@@ -1633,73 +1607,73 @@ bool ggml_metal_graph_compute(
                    {
                        nth0 = 8;
                        nth1 = 8;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
                    } break;
                case GGML_TYPE_Q4_1:
                    {
                        nth0 = 8;
                        nth1 = 8;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
                    } break;
                case GGML_TYPE_Q5_0:
                    {
                        nth0 = 8;
                        nth1 = 8;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
                    } break;
                case GGML_TYPE_Q5_1:
                    {
                        nth0 = 8;
                        nth1 = 8;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
                    } break;
                case GGML_TYPE_Q8_0:
                    {
                        nth0 = 8;
                        nth1 = 8;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                    } break;
                case GGML_TYPE_Q2_K:
                    {
                        nth0 = 2;
                        nth1 = 32;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
                    } break;
                case GGML_TYPE_Q3_K:
                    {
                        nth0 = 2;
                        nth1 = 32;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
                    } break;
                case GGML_TYPE_Q4_K:
                    {
                        nth0 = 4; //1;
                        nth1 = 8; //32;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
                    } break;
                case GGML_TYPE_Q5_K:
                    {
                        nth0 = 2;
                        nth1 = 32;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
                    } break;
                case GGML_TYPE_Q6_K:
                    {
                        nth0 = 2;
                        nth1 = 32;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
                    } break;
                case GGML_TYPE_IQ2_XXS:
                    {
                        nth0 = 4;
                        nth1 = 16;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
                    } break;
                case GGML_TYPE_IQ2_XS:
                    {
                        nth0 = 4;
                        nth1 = 16;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                    } break;
                default:
                    {
@@ -1712,6 +1686,7 @@ bool ggml_metal_graph_compute(
                        GGML_ASSERT(ne00 >= nth0*nth1);
                    }
 
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1818,23 +1793,28 @@ bool ggml_metal_graph_compute(
            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && ne20 % 32 == 0 && ne20 >= 64 && ne11 > ne11_mm_min) {
+
+           id<MTLComputePipelineState> pipeline = nil;
+
            switch (src2->type) {
-               case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
-               case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
-               case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
-               case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
-               case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
-               case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
-               case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
-               case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
-               case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
-               case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
-               case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
-               case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
-               case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
-               case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
+               case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
+               case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
+               case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
+               case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
+               case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
+               case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
+               case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
+               case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
+               case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
+               case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
+               case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
+               case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
+               case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
+               case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
                default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
            }
+
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1874,91 +1854,93 @@ bool ggml_metal_graph_compute(
            int nrows = 1;
            //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+           id<MTLComputePipelineState> pipeline = nil;
+
            // use custom matrix x vector kernel
            switch (src2t) {
                case GGML_TYPE_F32:
                    {
                        GGML_ASSERT(src1t == GGML_TYPE_F32);
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
                    } break;
                case GGML_TYPE_F16:
                    {
                        GGML_ASSERT(src1t == GGML_TYPE_F32);
                        nth0 = 32;
                        nth1 = 1;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
                    } break;
                case GGML_TYPE_Q4_0:
                    {
                        nth0 = 8;
                        nth1 = 8;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
                    } break;
                case GGML_TYPE_Q4_1:
                    {
                        nth0 = 8;
                        nth1 = 8;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
                    } break;
                case GGML_TYPE_Q5_0:
                    {
                        nth0 = 8;
                        nth1 = 8;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
                    } break;
                case GGML_TYPE_Q5_1:
                    {
                        nth0 = 8;
                        nth1 = 8;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
                    } break;
                case GGML_TYPE_Q8_0:
                    {
                        nth0 = 8;
                        nth1 = 8;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                    } break;
                case GGML_TYPE_Q2_K:
                    {
                        nth0 = 2;
                        nth1 = 32;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
                    } break;
                case GGML_TYPE_Q3_K:
                    {
                        nth0 = 2;
                        nth1 = 32;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
                    } break;
                case GGML_TYPE_Q4_K:
                    {
                        nth0 = 4; //1;
                        nth1 = 8; //32;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
                    } break;
                case GGML_TYPE_Q5_K:
                    {
                        nth0 = 2;
                        nth1 = 32;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
                    } break;
                case GGML_TYPE_Q6_K:
                    {
                        nth0 = 2;
                        nth1 = 32;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
                    } break;
                case GGML_TYPE_IQ2_XXS:
                    {
                        nth0 = 4;
                        nth1 = 16;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
                    } break;
                case GGML_TYPE_IQ2_XS:
                    {
                        nth0 = 4;
                        nth1 = 16;
-                       [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                    } break;
                default:
                    {
@@ -1973,6 +1955,7 @@ bool ggml_metal_graph_compute(
            const int64_t _ne1 = 1; // kernels needs a reference in constant memory
 
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -2040,25 +2023,28 @@ bool ggml_metal_graph_compute(
        } break;
        case GGML_OP_GET_ROWS:
            {
+               id<MTLComputePipelineState> pipeline = nil;
+
                switch (src0->type) {
-                   case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_get_rows_f32]; break;
-                   case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
-                   case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
-                   case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
-                   case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
-                   case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
-                   case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
-                   case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
-                   case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
-                   case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
-                   case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
-                   case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
-                   case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
-                   case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
-                   case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
+                   case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break;
+                   case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break;
+                   case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break;
+                   case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break;
+                   case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
+                   case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
+                   case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
+                   case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
+                   case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
+                   case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
+                   case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break;
+                   case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break;
+                   case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
+                   case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+                   case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
                    default: GGML_ASSERT(false && "not implemented");
                }
+               [encoder setComputePipelineState:pipeline];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -2086,7 +2072,9 @@ bool ggml_metal_graph_compute(
                nth *= 2;
            }
-           [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+           id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@@ -2115,7 +2103,9 @@ bool ggml_metal_graph_compute(
                // nth *= 2;
            //}
-           [encoder setComputePipelineState:ctx->pipeline_group_norm];
+           id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
+
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@@ -2137,7 +2127,9 @@ bool ggml_metal_graph_compute(
            const int nth = MIN(256, ne00);
-           [encoder setComputePipelineState:ctx->pipeline_norm];
+           id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
+
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@@ -2164,7 +2156,9 @@ bool ggml_metal_graph_compute(
            const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
            const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-           [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+           id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
+
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@@ -2209,12 +2203,15 @@ bool ggml_metal_graph_compute(
            memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
            memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+           id<MTLComputePipelineState> pipeline = nil;
+
            switch (src0->type) {
-               case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
-               case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+               case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
+               case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
                default: GGML_ASSERT(false);
            };
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -2277,12 +2274,15 @@ bool ggml_metal_graph_compute(
            const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
            const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+           id<MTLComputePipelineState> pipeline = nil;
+
            switch (src0->type) {
                case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
-               case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+               case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
                default: GGML_ASSERT(false);
            };
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
@@ -2305,7 +2305,9 @@ bool ggml_metal_graph_compute(
            const int sf = dst->op_params[0];
-           [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+           const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
@@ -2326,7 +2328,7 @@ bool ggml_metal_graph_compute(
            [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
            [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
-           const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
+           const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
            [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
        } break;
@@ -2334,7 +2336,9 @@ bool ggml_metal_graph_compute(
            {
                GGML_ASSERT(src0->type == GGML_TYPE_F32);
-               [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+               [encoder setComputePipelineState:pipeline];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
@@ -2367,12 +2371,15 @@ bool ggml_metal_graph_compute(
            enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+           id<MTLComputePipelineState> pipeline = nil;
+
            switch (order) {
-               case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break;
-               case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+               case GGML_SORT_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break;
+               case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
                default: GGML_ASSERT(false);
            };
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@@ -2386,7 +2393,9 @@ bool ggml_metal_graph_compute(
            float slope;
            memcpy(&slope, dst->op_params, sizeof(float));
-           [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+           id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
@@ -2403,33 +2412,36 @@ bool ggml_metal_graph_compute(
            int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+           id<MTLComputePipelineState> pipeline = nil;
+
            switch (src0t) {
                case GGML_TYPE_F32:
                    {
                        GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
                        switch (dstt) {
-                           case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
-                           case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
-                           case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
-                           case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
-                           case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
-                           //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
-                           //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
+                           case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+                           case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+                           case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+                           case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+                           case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+                           //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+                           //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
                            default: GGML_ASSERT(false && "not implemented");
                        };
                    } break;
                case GGML_TYPE_F16:
                    {
                        switch (dstt) {
-                           case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-                           case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
+                           case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+                           case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
                            default: GGML_ASSERT(false && "not implemented");
                        };
                    } break;
                default: GGML_ASSERT(false && "not implemented");
            }
+           [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@@ -2794,9 +2806,9 @@ static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
 }
 
 static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    return ggml_metal_supports_op(op);
+    struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
 
-    UNUSED(backend);
+    return ggml_metal_supports_op(metal_ctx, op);
 }
 
 static struct ggml_backend_i ggml_backend_metal_i = {