From b8065d90f5fdcdb445a8fb3f4717cba54c332cac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 16 Dec 2022 19:43:16 +0200 Subject: [PATCH] main : add "--prompt" command line argument (#90) This allows to provide an initial prompt to be used at the start of the processing. --- examples/main/main.cpp | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4071bd2..3ef576d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -73,8 +73,9 @@ struct whisper_params { bool print_colors = false; bool no_timestamps = false; - std::string language = "en"; - std::string model = "models/ggml-base.en.bin"; + std::string language = "en"; + std::string prompt = ""; + std::string model = "models/ggml-base.en.bin"; std::vector fname_inp = {}; }; @@ -113,6 +114,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_inp.push_back(argv[++i]); } else { @@ -150,6 +152,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, "\n"); @@ -462,6 +465,22 @@ int main(int argc, char ** argv) { return 3; } + // initial prompt + std::vector prompt_tokens; + + if (params.prompt.size() > 0) { + prompt_tokens.resize(1024); + prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size())); + + fprintf(stderr, "\n"); + fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str()); + fprintf(stderr, "initial tokens: [ "); + for (int i = 0; i < (int) prompt_tokens.size(); ++i) { + fprintf(stderr, "%d ", prompt_tokens[i]); + } + fprintf(stderr, "]\n"); + } + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; @@ -577,7 +596,6 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n"); } - // run the inference { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); @@ -599,6 +617,9 @@ int main(int argc, char ** argv) { wparams.speed_up = params.speed_up; + wparams.prompt_tokens = prompt_tokens.size() == 0 ? nullptr : prompt_tokens.data(); + wparams.prompt_n_tokens = prompt_tokens.size() == 0 ? 0 : prompt_tokens.size(); + whisper_print_user_data user_data = { ¶ms, &pcmf32s }; // this callback is called on each new segment