diff --git a/ggml-metal.m b/ggml-metal.m index af16540..98f0db6 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -24,10 +24,7 @@ #define UNUSED(x) (void)(x) -#define GGML_METAL_MAX_KERNELS 256 - struct ggml_metal_kernel { - id function; id pipeline; }; @@ -159,11 +156,10 @@ struct ggml_metal_context { id device; id queue; - id library; dispatch_queue_t d_queue; - struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS]; + struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; bool support_simdgroup_reduction; bool support_simdgroup_mm; @@ -248,6 +244,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { ctx->queue = [ctx->device newCommandQueue]; ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); + id metal_library; + // load library { NSBundle * bundle = nil; @@ -262,7 +260,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { // pre-compiled library found NSURL * libURL = [NSURL fileURLWithPath:libPath]; GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); - ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + metal_library = [ctx->device newLibraryWithURL:libURL error:&error]; if (error) { GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; @@ -304,7 +302,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { //[options setFastMathEnabled:false]; - ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + metal_library = [ctx->device newLibraryWithSource:src options:options error:&error]; if (error) { GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; @@ -371,8 +369,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { { NSError * error = nil; - for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) { - ctx->kernels[i].function = nil; + for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { ctx->kernels[i].pipeline = nil; } @@ -384,10 +381,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { #define GGML_METAL_ADD_KERNEL(e, name, supported) \ if (supported) { \ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ - kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ - kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ + id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ + kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \ + [metal_function release]; \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + [metal_library release]; \ return NULL; \ } \ } else { \ @@ -516,23 +515,17 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } + [metal_library release]; return ctx; } static void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); - for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) { - if (ctx->kernels[i].pipeline) { - [ctx->kernels[i].pipeline release]; - } - - if (ctx->kernels[i].function) { - [ctx->kernels[i].function release]; - } + for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { + [ctx->kernels[i].pipeline release]; } - [ctx->library release]; [ctx->queue release]; [ctx->device release];