From 940cdb13964a563d86c7dc6e160a43ec89b8bb2e Mon Sep 17 00:00:00 2001 From: mkiol Date: Sun, 8 Oct 2023 16:22:24 +0200 Subject: [PATCH] whisper : abort callback improvements (#1345) * whisper : initialize abort_callback to null * whisper : add example how to use abort_callback --- examples/main/main.cpp | 16 ++++++++++++++-- whisper.cpp | 3 +++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 60c1cca..cdd16ac 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -944,8 +944,9 @@ int main(int argc, char ** argv) { wparams.progress_callback_user_data = &user_data; } - // example for abort mechanism - // in this example, we do not abort the processing, but we could if the flag is set to true + // examples for abort mechanism + // in examples below, we do not abort the processing, but we could if the flag is set to true + // the callback is called before every encoder run - if it returns false, the processing is aborted { static bool is_aborted = false; // NOTE: this should be atomic to avoid data race @@ -957,6 +958,17 @@ int main(int argc, char ** argv) { wparams.encoder_begin_callback_user_data = &is_aborted; } + // the callback is called before every computation - if it returns true, the computation is aborted + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + wparams.abort_callback = [](void * user_data) { + bool is_aborted = *(bool*)user_data; + return is_aborted; + }; + wparams.abort_callback_user_data = &is_aborted; + } + if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 10; diff --git a/whisper.cpp b/whisper.cpp index 403c2d0..ccac6aa 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3773,6 +3773,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.encoder_begin_callback =*/ nullptr, /*.encoder_begin_callback_user_data =*/ nullptr, + /*.abort_callback =*/ nullptr, + /*.abort_callback_user_data =*/ nullptr, + /*.logits_filter_callback =*/ nullptr, /*.logits_filter_callback_user_data =*/ nullptr, };