// Patch summary (free-text preamble; everything from the first "diff --git" header onward is unchanged).
//
// Purpose: add an experimental "speed up" option that is threaded from the example
// command lines down to the mel-spectrogram computation in whisper.cpp.
//
// examples/main/main.cpp and examples/stream/stream.cpp:
//   - whisper_params gains `bool speed_up = false`.
//   - whisper_params_parse sets it to true on `-su` / `--speed-up` (the flag takes no value).
//   - whisper_print_usage prints the flag together with its current default.
//   - main copies it into `wparams.speed_up` before the whisper_full call.
//   - NOTE(review): the main.cpp hunk `@@ -454,7 +458,7 @@` removes and re-adds what looks
//     like an empty line, i.e. presumably a whitespace-only change - confirm.
//
// whisper.cpp:
//   - log_mel_spectrogram takes a new `const bool speed_up` parameter, placed before the
//     `whisper_mel & mel` output parameter. When it is set, n_fft becomes 1 + fft_size/4
//     instead of 1 + fft_size/2, and each fft_out[j] for j < n_fft is replaced by the
//     average of fft_out[2*j] and fft_out[2*j + 1] before the mel filter loop runs.
//   - whisper_pcm_to_mel now passes `false` for speed_up.
//   - New whisper_pcm_to_mel_phase_vocoder calls log_mel_spectrogram with 2*WHISPER_N_FFT,
//     2*WHISPER_HOP_LENGTH and speed_up = true. Its tail is outside the hunk and not
//     visible here.
//   - whisper_full_default_params initialises `speed_up` to false in both initializer
//     lists touched by this patch.
//   - whisper_full calls whisper_pcm_to_mel_phase_vocoder when params.speed_up is set and
//     whisper_pcm_to_mel otherwise. In both segment-emitting paths it doubles t0 and t1
//     (as tt0 / tt1) when speed_up is set, and uses the doubled values both for the
//     realtime printout and for the entries pushed into result_all.
//
// whisper.h:
//   - The params struct gains `bool speed_up`, commented as [EXPERIMENTAL], placed between
//     `max_len` and `language`.
//
// NOTE(review): the line breaks of this patch appear to have been collapsed (all hunks sit
// on six physical lines); they must be restored before the patch can be applied.
diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 7058031..204703c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -59,6 +59,7 @@ struct whisper_params { float word_thold = 0.01f; + bool speed_up = false; bool verbose = false; bool translate = false; bool output_txt = false; @@ -104,6 +105,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.max_len = std::stoi(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); + } else if (arg == "-su" || arg == "--speed-up") { + params.speed_up = true; } else if (arg == "-v" || arg == "--verbose") { params.verbose = true; } else if (arg == "--translate") { @@ -161,6 +164,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n"); fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len); fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold); + fprintf(stderr, " -su, --speed-up speed up audio by factor of 2 (faster processing, reduced accuracy, default: %s)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " -otxt, --output-txt output result in a text file\n"); @@ -454,7 +458,7 @@ int main(int argc, char ** argv) { std::vector pcmf32; { drwav wav; - + if (fname_inp == "-") { std::vector wav_data; { @@ -563,6 +567,8 @@ int main(int argc, char ** argv) { wparams.thold_pt = params.word_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 
60 : params.max_len; + wparams.speed_up = params.speed_up; + // this callback is called on each new segment if (!wparams.print_realtime) { wparams.new_segment_callback = whisper_print_segment_callback; diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index c39375a..718c815 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -41,6 +41,7 @@ struct whisper_params { int32_t length_ms = 10000; int32_t capture_id = -1; + bool speed_up = false; bool verbose = false; bool translate = false; bool no_context = true; @@ -68,6 +69,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.length_ms = std::stoi(argv[++i]); } else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); + } else if (arg == "-su" || arg == "--speed-up") { + params.speed_up = true; } else if (arg == "-v" || arg == "--verbose") { params.verbose = true; } else if (arg == "--translate") { @@ -113,6 +116,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms); fprintf(stderr, " --length N audio length in milliseconds (default: %d)\n", params.length_ms); fprintf(stderr, " -c ID, --capture ID capture device ID (default: -1)\n"); + fprintf(stderr, " -su, --speed-up speed up audio by factor of 2 (faster processing, reduced accuracy, default: %s)\n", params.speed_up ? 
"true" : "false"); fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " -kc, --keep-context keep text context from earlier audio (default: false)\n"); @@ -326,6 +330,8 @@ int main(int argc, char ** argv) { wparams.language = params.language.c_str(); wparams.n_threads = params.n_threads; + wparams.speed_up = params.speed_up; + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 6; diff --git a/whisper.cpp b/whisper.cpp index 7078863..d894f69 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2031,6 +2031,7 @@ static bool log_mel_spectrogram( const int n_mel, const int n_threads, const whisper_filters & filters, + const bool speed_up, whisper_mel & mel) { // Hanning window @@ -2044,7 +2045,7 @@ static bool log_mel_spectrogram( mel.n_len = (n_samples)/fft_step; mel.data.resize(mel.n_mel*mel.n_len); - const int n_fft = 1 + fft_size/2; + const int n_fft = 1 + (speed_up ? 
fft_size/4 : fft_size/2); //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len); //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate); @@ -2091,6 +2092,13 @@ static bool log_mel_spectrogram( //} } + if (speed_up) { + // scale down in the frequency domain results in a speed up in the time domain + for (int j = 0; j < n_fft; j++) { + fft_out[j] = 0.5*(fft_out[2*j] + fft_out[2*j + 1]); + } + } + // mel spectrogram for (int j = 0; j < mel.n_mel; j++) { double sum = 0.0; @@ -2171,7 +2179,21 @@ void whisper_free(struct whisper_context * ctx) { int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { const int64_t t_start_us = ggml_time_us(); - if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, ctx->mel)) { + if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { + fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + ctx->t_mel_us = ggml_time_us() - t_start_us; + + return 0; +} + +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 +int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { + const int64_t t_start_us = ggml_time_us(); + + if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -2353,6 +2375,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.thold_ptsum =*/ 0.01f, /*.max_len =*/ 0, + /*.speed_up =*/ false, + /*.language =*/ "en", /*.greedy =*/ { @@ -2391,6 +2415,8 @@ struct 
whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.thold_ptsum =*/ 0.01f, /*.max_len =*/ 0, + /*.speed_up =*/ false, + /*.language =*/ "en", /*.greedy =*/ { @@ -2485,9 +2511,16 @@ int whisper_full( result_all.clear(); // compute log mel spectrogram - if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { - fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); - return -1; + if (params.speed_up) { + if (whisper_pcm_to_mel_phase_vocoder(ctx, samples, n_samples, params.n_threads) != 0) { + fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); + return -1; + } + } else { + if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) { + fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__); + return -1; + } } if (params.token_timestamps) { @@ -2673,16 +2706,19 @@ int whisper_full( if (tokens_cur[i].id > whisper_token_beg(ctx)) { const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); if (!text.empty()) { + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + if (params.print_realtime) { if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str()); + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); } else { printf("%s", text.c_str()); fflush(stdout); } } - result_all.push_back({ t0, t1, text, {} }); + result_all.push_back({ tt0, tt1, text, {} }); for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -2714,16 +2750,19 @@ int whisper_full( if (!text.empty()) { const auto t1 = seek + seek_delta; + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 
2*t1 : t1; + if (params.print_realtime) { if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str()); + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); } else { printf("%s", text.c_str()); fflush(stdout); } } - result_all.push_back({ t0, t1, text, {} }); + result_all.push_back({ tt0, tt1, text, {} }); for (int j = i0; j < (int) tokens_cur.size(); j++) { result_all.back().tokens.push_back(tokens_cur[j]); } diff --git a/whisper.h b/whisper.h index 4c112f4..ea677ea 100644 --- a/whisper.h +++ b/whisper.h @@ -202,6 +202,9 @@ extern "C" { float thold_ptsum; // timestamp token sum probability threshold (~0.01) int max_len; // max segment length in characters + // [EXPERIMENTAL] speed-up techniques + bool speed_up; // speed-up the audio by 2x using Phase Vocoder + const char * language; struct {