From 2f668c330e979ff7995e0f42e800a7795a32ec16 Mon Sep 17 00:00:00 2001 From: mkiol Date: Wed, 4 Oct 2023 10:57:55 +0200 Subject: [PATCH] whisper : add abort callback (#1335) --- whisper.cpp | 50 +++++++++++++++++++++++++++++++------------------- whisper.h | 9 +++++++++ 2 files changed, 40 insertions(+), 19 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 916883c..403c2d0 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -125,9 +125,17 @@ static void byteswap_tensor(ggml_tensor * tensor) { // ggml helpers // -static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { +static void ggml_graph_compute_helper( + std::vector & buf, + ggml_cgraph * graph, + int n_threads, + whisper_abort_callback abort_callback, + void * abort_callback_data) { struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + plan.abort_callback = abort_callback; + plan.abort_callback_data = abort_callback_data; + if (plan.work_size > 0) { buf.resize(plan.work_size); plan.work_data = buf.data(); @@ -1922,7 +1930,9 @@ static bool whisper_encode_internal( whisper_context & wctx, whisper_state & wstate, const int mel_offset, - const int n_threads) { + const int n_threads, + whisper_abort_callback abort_callback, + void * abort_callback_data) { const int64_t t_start_us = ggml_time_us(); // conv @@ -1936,7 +1946,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); if (!whisper_encode_external(wstate)) { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); } } @@ -1955,10 +1965,10 @@ static bool whisper_encode_internal( ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); ggml_metal_graph_compute(wstate.ctx_metal, gf); } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); } #else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); #endif } @@ -1977,10 +1987,10 @@ static bool whisper_encode_internal( ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); ggml_metal_graph_compute(wstate.ctx_metal, gf); } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); } #else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); #endif } @@ -2346,7 +2356,9 @@ static bool whisper_decode_internal( const whisper_token * tokens, const int n_tokens, const int n_past, - const int n_threads) { + const int n_threads, + whisper_abort_callback abort_callback, + void * abort_callback_data) { const int64_t t_start_us = ggml_time_us(); const auto & model = wctx.model; @@ -2375,10 +2387,10 @@ static bool whisper_decode_internal( ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); ggml_metal_graph_compute(wstate.ctx_metal, gf); } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); } #else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads); + ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); #endif } @@ -3290,7 +3302,7 @@ int whisper_set_mel( } int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { - if (!whisper_encode_internal(*ctx, *state, offset, n_threads)) { + if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { log("%s: failed to eval\n", __func__); return -1; } @@ -3299,7 +3311,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state } int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { - if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads)) { + if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) { log("%s: failed to eval\n", __func__); return -1; } @@ -3310,7 +3322,7 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { const int selected_decoder_id = 0; - if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) { log("%s: failed to eval\n", __func__); return 1; } @@ -3327,7 +3339,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i return false; } - if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { + if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) { log("%s: failed to eval\n", __func__); return 1; } @@ -4594,7 +4606,7 @@ int whisper_full_with_state( } // encode audio features starting at offset seek - if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads)) { + if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { log("%s: failed to encode\n", __func__); return -6; } @@ -4677,7 +4689,7 @@ int whisper_full_with_state( } WHISPER_PRINT_DEBUG("\n\n"); - if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { + if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { log("%s: failed to decode\n", __func__); return -7; } @@ -4901,7 +4913,7 @@ int whisper_full_with_state( //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); - if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { + if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { log("%s: failed to decode\n", __func__); return -8; } @@ -5473,12 +5485,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { double tsum = 0.0; // heat-up - ggml_graph_compute_helper(work, &gf, n_threads); + ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr); for (int i = 0; i < n_max; ++i) { const int64_t t0 = ggml_time_us(); - ggml_graph_compute_helper(work, &gf, n_threads); + ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr); const int64_t t1 = ggml_time_us(); diff --git a/whisper.h b/whisper.h index 6c0efc1..c3118c9 100644 --- a/whisper.h +++ b/whisper.h @@ -334,6 +334,11 @@ extern "C" { // If it returns false, the computation is aborted typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data); + // Abort callback + // If not NULL, called before ggml computation + // If it returns true, the computation is aborted + typedef bool (*whisper_abort_callback)(void * user_data); + // Logits filter callback // Can be used to modify the logits before sampling // If not NULL, called after applying temperature to logits @@ -428,6 +433,10 @@ extern "C" { whisper_encoder_begin_callback encoder_begin_callback; void * encoder_begin_callback_user_data; + // called each time before ggml computation starts + whisper_abort_callback abort_callback; + void * abort_callback_user_data; + // called by each decoder to filter obtained logits whisper_logits_filter_callback logits_filter_callback; void * logits_filter_callback_user_data;