diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3e8c5aa..c6bf32e 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -66,6 +66,7 @@ struct whisper_params { bool speed_up = false; bool translate = false; + bool detect_language= false; bool diarize = false; bool split_on_word = false; bool no_fallback = false; @@ -141,6 +142,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; } else if ( arg == "--prompt") { params.prompt = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } @@ -191,6 +193,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); @@ -739,6 +742,9 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); } } + if (params.detect_language) { + params.language = "auto"; + } fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, params.n_processors, @@ -761,6 +767,7 @@ int main(int argc, char ** argv) { wparams.print_special = params.print_special; wparams.translate = params.translate; wparams.language = params.language.c_str(); + wparams.detect_language = params.detect_language; wparams.n_threads = params.n_threads; wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; wparams.offset_ms = params.offset_t_ms; diff --git a/whisper.cpp b/whisper.cpp index df283ec..158aa0b 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3312,6 +3312,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.prompt_n_tokens =*/ 0, /*.language =*/ "en", + /*.detect_language =*/ false, /*.suppress_blank =*/ true, /*.suppress_non_speech_tokens =*/ false, @@ -3898,7 +3899,7 @@ int whisper_full_with_state( } // auto-detect language if not specified - if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) { + if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) { std::vector probs(whisper_lang_max_id() + 1, 0.0f); const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); @@ -3910,6 +3911,9 @@ int whisper_full_with_state( params.language = whisper_lang_str(lang_id); fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); + if (params.detect_language) { + return 0; + } } if (params.token_timestamps) { diff --git a/whisper.h b/whisper.h index 3d689a4..2d5b3eb 100644 --- a/whisper.h +++ b/whisper.h @@ -365,6 +365,7 @@ extern "C" { // for auto-detection, set to nullptr, "" or "auto" const char * language; + bool detect_language; // common decoding parameters: bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89