From f75e1197f15dea464f6076cd8099438b7a61cd91 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Fri, 9 Feb 2024 10:42:27 +0100 Subject: [PATCH] ggml : add abort_callback for cpu backend (ggml/725) * a way to use abort_callback with the cpu backend * whisper update --- ggml-backend.c | 26 ++++++++++++++++++++++---- ggml-backend.h | 5 +++-- ggml.c | 2 +- ggml.h | 9 +++++++-- whisper.cpp | 8 ++++---- whisper.h | 7 +------ 6 files changed, 38 insertions(+), 19 deletions(-) diff --git a/ggml-backend.c b/ggml-backend.c index 0764dfe..532da8e 100644 --- a/ggml-backend.c +++ b/ggml-backend.c @@ -653,6 +653,9 @@ struct ggml_backend_cpu_context { int n_threads; void * work_data; size_t work_size; + + ggml_abort_callback abort_callback; + void * abort_callback_data; }; GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) { @@ -691,6 +694,9 @@ GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(gg cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); } + cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; + cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; + return cpu_plan; } @@ -721,9 +727,11 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size); cpu_ctx->work_size = cplan.work_size; } - cplan.work_data = cpu_ctx->work_data; + cplan.abort_callback = cpu_ctx->abort_callback; + cplan.abort_callback_data = cpu_ctx->abort_callback_data; + ggml_graph_compute(cgraph, &cplan); return true; } @@ -759,9 +767,11 @@ static struct ggml_backend_i cpu_backend_i = { ggml_backend_t ggml_backend_cpu_init(void) { struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); - ctx->n_threads = GGML_DEFAULT_N_THREADS; - ctx->work_data = NULL; - ctx->work_size = 0; + ctx->n_threads = GGML_DEFAULT_N_THREADS; + ctx->work_data = NULL; + ctx->work_size = 0; + ctx->abort_callback = NULL; + ctx->abort_callback_data = NULL; ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); @@ -783,6 +793,14 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { ctx->n_threads = n_threads; } +void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = abort_callback_data; +} + GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size); } diff --git a/ggml-backend.h b/ggml-backend.h index 8b8160f..282b3a9 100644 --- a/ggml-backend.h +++ b/ggml-backend.h @@ -83,8 +83,9 @@ extern "C" { GGML_API ggml_backend_t ggml_backend_cpu_init(void); - GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend); - GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); + GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend); + GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); + GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); // Create a backend buffer from an existing pointer GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); diff --git a/ggml.c b/ggml.c index a7a9ea3..3499b73 100644 --- a/ggml.c +++ b/ggml.c @@ -16560,7 +16560,7 @@ struct ggml_compute_state_shared { atomic_int node_n; // active graph node atomic_int node_task; // active graph node task phase - bool (*abort_callback)(void * data); // abort ggml_graph_compute when true + ggml_abort_callback abort_callback; // abort ggml_graph_compute when true void * abort_callback_data; }; diff --git a/ggml.h b/ggml.h index bf782e6..e20b14f 100644 --- a/ggml.h +++ b/ggml.h @@ -567,6 +567,11 @@ extern "C" { static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); + // Abort callback + // If not NULL, called before ggml computation + // If it returns true, the computation is aborted + typedef bool (*ggml_abort_callback)(void * data); + // the compute plan that needs to be prepared for ggml_graph_compute() // since https://github.com/ggerganov/ggml/issues/287 struct ggml_cplan { @@ -576,8 +581,8 @@ extern "C" { int n_threads; // abort ggml_graph_compute when true - bool (*abort_callback)(void * data); - void * abort_callback_data; + ggml_abort_callback abort_callback; + void * abort_callback_data; }; enum ggml_cgraph_eval_order { diff --git a/whisper.cpp b/whisper.cpp index 59d5cff..28e3804 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -156,11 +156,11 @@ static bool ggml_graph_compute_helper( struct ggml_cgraph * graph, std::vector & buf, int n_threads, - whisper_abort_callback abort_callback, + ggml_abort_callback abort_callback, void * abort_callback_data) { struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); - plan.abort_callback = abort_callback; + plan.abort_callback = abort_callback; plan.abort_callback_data = abort_callback_data; if (plan.work_size > 0) { @@ -2130,7 +2130,7 @@ static bool whisper_encode_internal( whisper_state & wstate, const int mel_offset, const int n_threads, - whisper_abort_callback abort_callback, + ggml_abort_callback abort_callback, void * abort_callback_data) { const int64_t t_start_us = ggml_time_us(); @@ -2561,7 +2561,7 @@ static bool whisper_decode_internal( whisper_state & wstate, const whisper_batch & batch, const int n_threads, - whisper_abort_callback abort_callback, + ggml_abort_callback abort_callback, void * abort_callback_data) { const int64_t t_start_us = ggml_time_us(); diff --git a/whisper.h b/whisper.h index d571a12..a5371eb 100644 --- a/whisper.h +++ b/whisper.h @@ -412,11 +412,6 @@ extern "C" { // If it returns false, the computation is aborted typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data); - // Abort callback - // If not NULL, called before ggml computation - // If it returns true, the computation is aborted - typedef bool (*whisper_abort_callback)(void * user_data); - // Logits filter callback // Can be used to modify the logits before sampling // If not NULL, called after applying temperature to logits @@ -513,7 +508,7 @@ extern "C" { void * encoder_begin_callback_user_data; // called each time before ggml computation starts - whisper_abort_callback abort_callback; + ggml_abort_callback abort_callback; void * abort_callback_user_data; // called by each decoder to filter obtained logits