From 5f14ee0b0cd06f1c4790e6123df4b38ace637e88 Mon Sep 17 00:00:00 2001 From: Jack Mousseau Date: Mon, 29 Jan 2024 01:22:23 -0800 Subject: [PATCH] metal : add debug capture backend function (ggml/694) Co-authored-by: Georgi Gerganov --- ggml-metal.h | 3 +++ ggml-metal.m | 40 ++++++++++++++++++++++++++++++++++------ 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/ggml-metal.h b/ggml-metal.h index df83a1807..a5c542189 100644 --- a/ggml-metal.h +++ b/ggml-metal.h @@ -57,6 +57,9 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(voi // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family); +// capture all command buffers committed the next time `ggml_backend_graph_compute` is called +GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); + #ifdef __cplusplus } #endif diff --git a/ggml-metal.m b/ggml-metal.m index 1b02493f8..7e148b6bd 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -168,6 +168,8 @@ struct ggml_metal_context { bool support_simdgroup_reduction; bool support_simdgroup_mm; + + bool should_capture_next_compute; }; // MSL code @@ -687,6 +689,20 @@ static bool ggml_metal_graph_compute( const int n_cb = ctx->n_cb; const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; + const bool should_capture = ctx->should_capture_next_compute; + if (should_capture) { + ctx->should_capture_next_compute = false; + + MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->queue; + + NSError * error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + GGML_ASSERT(!"capture failed"); + } + } + id command_buffer_builder[n_cb]; for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; @@ -695,6 +711,7 @@ static bool ggml_metal_graph_compute( // enqueue the command buffers in order to specify their execution order [command_buffer enqueue]; } + const id *command_buffers = command_buffer_builder; dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) { @@ -741,9 +758,9 @@ static bool ggml_metal_graph_compute( GGML_ASSERT(!"unsupported op"); } -#ifndef GGML_METAL_NDEBUG - [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; -#endif + if (should_capture) { + [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; + } const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; @@ -2218,9 +2235,9 @@ static bool ggml_metal_graph_compute( } } -#ifndef GGML_METAL_NDEBUG - [encoder popDebugGroup]; -#endif + if (should_capture) { + [encoder popDebugGroup]; + } } [encoder endEncoding]; @@ -2242,6 +2259,10 @@ static bool ggml_metal_graph_compute( } } + if (should_capture) { + [[MTLCaptureManager sharedCaptureManager] stopCapture]; + } + return true; } @@ -2613,6 +2634,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; } +void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context; + ctx->should_capture_next_compute = true; +} + GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {