diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d2a274f..4aa7f24 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -57,7 +57,7 @@ struct whisper_params { int32_t duration_ms = 0; int32_t max_context = -1; int32_t max_len = 0; - int32_t best_of = 5; + int32_t best_of = 2; int32_t beam_size = -1; float word_thold = 0.01f; diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index e0a53de..660b4df 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -43,6 +43,7 @@ struct whisper_params { bool speed_up = false; bool translate = false; + bool no_fallback = false; bool print_special = false; bool no_context = true; bool no_timestamps = false; @@ -73,6 +74,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } @@ -94,22 +96,23 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms); - fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms); - fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms); - fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); - fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); - fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); - fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); - fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms); + fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms); + fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms); + fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); + fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); + fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, "\n"); } @@ -297,7 +300,8 @@ int main(int argc, char ** argv) { wparams.speed_up = params.speed_up; // disable temperature fallback - wparams.temperature_inc = -1.0f; + //wparams.temperature_inc = -1.0f; + wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc; wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size(); diff --git a/whisper.cpp b/whisper.cpp index 8e9fa6c..3ba2f8e 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3220,7 +3220,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_initial_ts =*/ 1.0f, /*.length_penalty =*/ -1.0f, - /*.temperature_inc =*/ 0.0f, // TODO: temporary disabled until improve performance + /*.temperature_inc =*/ 0.4f, /*.entropy_thold =*/ 2.4f, /*.logprob_thold =*/ -1.0f, /*.no_speech_thold =*/ 0.6f, @@ -3252,13 +3252,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str case WHISPER_SAMPLING_GREEDY: { result.greedy = { - /*.best_of =*/ 1, + /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding }; } break; case WHISPER_SAMPLING_BEAM_SEARCH: { result.beam_search = { - /*.beam_size =*/ 5, + /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding /*.patience =*/ -1.0f, };