From d5afebd37c8267ce83cd2ec6a5994f2f20941859 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 2 Nov 2022 21:18:20 +0200 Subject: [PATCH] whisper : token-level timestamp refactoring (#49, #120) This turned out pretty good overall. The algorithm has been moved from main.cpp to whisper.cpp and can be reused for all subtitle types. This means that now you can specify the maximum length of the generated lines. Simply provide the "-ml" argument specifying the max length in number of characters. --- README.md | 3 +- examples/main/README.md | 5 +- examples/main/main.cpp | 422 +++++++------------------------------- whisper.cpp | 435 ++++++++++++++++++++++++++++++++++++++-- whisper.h | 23 ++- 5 files changed, 518 insertions(+), 370 deletions(-) diff --git a/README.md b/README.md index fdbc65e..a888880 100644 --- a/README.md +++ b/README.md @@ -101,13 +101,14 @@ options: -ot N, --offset-t N time offset in milliseconds (default: 0) -on N, --offset-n N segment index offset (default: 0) -mc N, --max-context N maximum number of text context tokens to store (default: max) + -ml N, --max-len N maximum segment length in characters (default: 0) -wt N, --word-thold N word timestamp probability threshold (default: 0.010000) -v, --verbose verbose output --translate translate from source language to english -otxt, --output-txt output result in a text file -ovtt, --output-vtt output result in a vtt file -osrt, --output-srt output result in a srt file - -owts, --output-words output word-level timestamps to a text file + -owts, --output-words output script for generating karaoke video -ps, --print_special print special tokens -pc, --print_colors print colors -nt, --no_timestamps do not print timestamps diff --git a/examples/main/README.md b/examples/main/README.md index 27f47ff..f2bf2a8 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -8,7 +8,6 @@ It can be used as a reference for using the `whisper.cpp` library in other proje usage: ./bin/main [options] 
file0.wav file1.wav ... -options: -h, --help show this help message and exit -s SEED, --seed SEED RNG seed (default: -1) -t N, --threads N number of threads to use during computation (default: 4) @@ -16,18 +15,20 @@ options: -ot N, --offset-t N time offset in milliseconds (default: 0) -on N, --offset-n N segment index offset (default: 0) -mc N, --max-context N maximum number of text context tokens to store (default: max) + -ml N, --max-len N maximum segment length in characters (default: 0) -wt N, --word-thold N word timestamp probability threshold (default: 0.010000) -v, --verbose verbose output --translate translate from source language to english -otxt, --output-txt output result in a text file -ovtt, --output-vtt output result in a vtt file -osrt, --output-srt output result in a srt file - -owts, --output-words output word-level timestamps to a text file + -owts, --output-words output script for generating karaoke video -ps, --print_special print special tokens -pc, --print_colors print colors -nt, --no_timestamps do not print timestamps -l LANG, --language LANG spoken language (default: en) -m FNAME, --model FNAME model path (default: models/ggml-base.en.bin) -f FNAME, --file FNAME input WAV file path + -h, --help show this help message and exit ``` diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 8343892..b589459 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -36,6 +36,7 @@ std::string to_timestamp(int64_t t, bool comma = false) { return std::string(buf); } +// helper function to replace substrings void replace_all(std::string & s, const std::string & search, const std::string & replace) { for (size_t pos = 0; ; pos += replace.length()) { pos = s.find(search, pos); @@ -45,31 +46,6 @@ void replace_all(std::string & s, const std::string & search, const std::string } } -// a cost-function that is high for text that takes longer to pronounce -float voice_length(const std::string & text) { - float res = 0.0f; - - for (size_t 
i = 0; i < text.size(); ++i) { - if (text[i] == ' ') { - res += 0.01f; - } else if (text[i] == ',') { - res += 2.00f; - } else if (text[i] == '.') { - res += 3.00f; - } else if (text[i] == '!') { - res += 3.00f; - } else if (text[i] == '?') { - res += 3.00f; - } else if (text[i] >= '0' && text[i] <= '9') { - res += 3.00f; - } else { - res += 1.00f; - } - } - - return res; -} - // command-line parameters struct whisper_params { int32_t seed = -1; // RNG seed, not used currently @@ -78,6 +54,7 @@ struct whisper_params { int32_t offset_t_ms = 0; int32_t offset_n = 0; int32_t max_context = -1; + int32_t max_len = 0; float word_thold = 0.01f; @@ -120,6 +97,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.offset_n = std::stoi(argv[++i]); } else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); + } else if (arg == "-ml" || arg == "--max-len") { + params.max_len = std::stoi(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-v" || arg == "--verbose") { @@ -176,13 +155,14 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms); fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n); fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n"); + fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len); fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold); fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " -otxt, --output-txt output result in a text file\n"); fprintf(stderr, " -ovtt, --output-vtt output 
result in a vtt file\n"); fprintf(stderr, " -osrt, --output-srt output result in a srt file\n"); - fprintf(stderr, " -owts, --output-words output word-level timestamps to a text file\n"); + fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n"); fprintf(stderr, " -ps, --print_special print special tokens\n"); fprintf(stderr, " -pc, --print_colors print colors\n"); fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); @@ -192,65 +172,67 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, "\n"); } -void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) { +void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { const whisper_params & params = *(whisper_params *) user_data; const int n_segments = whisper_full_n_segments(ctx); - // print the last segment - const int i = n_segments - 1; - if (i == 0) { + // print the last n_new segments + const int s0 = n_segments - n_new; + if (s0 == 0) { printf("\n"); } - if (params.no_timestamps) { - if (params.print_colors) { - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special_tokens == false) { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) { - continue; + for (int i = s0; i < n_segments; i++) { + if (params.no_timestamps) { + if (params.print_colors) { + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special_tokens == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } } + + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); } - - const char * 
text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); - - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); - - printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + printf("%s", text); } + fflush(stdout); } else { - const char * text = whisper_full_get_segment_text(ctx, i); - printf("%s", text); - } - fflush(stdout); - } else { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - if (params.print_colors) { - printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special_tokens == false) { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) { - continue; + if (params.print_colors) { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special_tokens == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } } + + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); } + printf("\n"); + } else { + const char * text = whisper_full_get_segment_text(ctx, i); - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); - - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 
3)*float(k_colors.size())))); - - printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); } - printf("\n"); - } else { - const char * text = whisper_full_get_segment_text(ctx, i); - - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); } } } @@ -320,297 +302,41 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_ return true; } -// word-level timestamps (experimental) -// TODO: make ffmpeg output optional -// TODO: extra pass to detect unused speech and assign to tokens +// karaoke video generation +// outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments -// TODO: move to whisper.h/whisper.cpp and add parameter to select max line-length of subtitles -bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector & pcmf32) { - std::vector pcm_avg(pcmf32.size(), 0); - - // average the fabs of the signal - { - const int hw = 32; - - for (int i = 0; i < pcmf32.size(); i++) { - float sum = 0; - for (int j = -hw; j <= hw; j++) { - if (i + j >= 0 && i + j < pcmf32.size()) { - sum += fabs(pcmf32[i + j]); - } - } - pcm_avg[i] = sum/(2*hw + 1); - } - } - - struct token_info { - int64_t t0 = -1; - int64_t t1 = -1; - - int64_t tt0 = -1; - int64_t tt1 = -1; - - whisper_token id; - whisper_token tid; - - float p = 0.0f; - float pt = 0.0f; - float ptsum = 0.0f; - - std::string text; - float vlen = 0.0f; // voice length of this token - }; - - int64_t t_beg = 0; - int64_t t_last = 0; - - whisper_token tid_last = 0; - +bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) { std::ofstream fout(fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + // TODO: become parameter + static const char * font 
= "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + fout << "!/bin/bash" << "\n"; fout << "\n"; - fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \""; - - bool is_first = true; + fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \""; for (int i = 0; i < whisper_full_n_segments(ctx); i++) { const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - const char *text = whisper_full_get_segment_text(ctx, i); - - const int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100)); - const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100)); - const int n = whisper_full_n_tokens(ctx, i); - std::vector tokens(n); - - if (n <= 1) { - continue; - } - + std::vector tokens(n); for (int j = 0; j < n; ++j) { - struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j); - - if (j == 0) { - if (token.id == whisper_token_beg(ctx)) { - tokens[j ].t0 = t0; - tokens[j ].t1 = t0; - tokens[j + 1].t0 = t0; - - t_beg = t0; - t_last = t0; - tid_last = whisper_token_beg(ctx); - } else { - tokens[j ].t0 = t_last; - } - } - - const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); - - tokens[j].id = token.id; - tokens[j].tid = token.tid; - tokens[j].p = token.p; - tokens[j].pt = token.pt; - tokens[j].ptsum = token.ptsum; - - tokens[j].text = whisper_token_to_str(ctx, token.id); - tokens[j].vlen = voice_length(tokens[j].text); - - if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last && tt <= t1) { - if (j > 0) { - tokens[j - 1].t1 = tt; - } - tokens[j].t0 = tt; - tid_last = token.tid; - } + tokens[j] = whisper_full_get_token_data(ctx, i, j); } - tokens[n - 2].t1 = t1; - tokens[n - 1].t0 = t1; - tokens[n - 1].t1 = t1; - - t_last = t1; - - // find intervals of tokens with unknown 
timestamps - // fill the timestamps by proportionally splitting the interval based on the token voice lengths - { - int p0 = 0; - int p1 = 0; - while (true) { - while (p1 < n && tokens[p1].t1 < 0) { - p1++; - } - - if (p1 >= n) { - p1--; - } - - if (p1 > p0) { - double psum = 0.0; - for (int j = p0; j <= p1; j++) { - psum += tokens[j].vlen; - } - - //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); - - const double dt = tokens[p1].t1 - tokens[p0].t0; - - // split the time proportionally to the voice length - for (int j = p0 + 1; j <= p1; j++) { - const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; - - tokens[j - 1].t1 = ct; - tokens[j ].t0 = ct; - } - } - - p1++; - p0 = p1; - if (p1 >= n) { - break; - } - } - } - - // fix up (just in case) - for (int j = 0; j < n - 1; j++) { - if (tokens[j].t1 < 0) { - tokens[j + 1].t0 = tokens[j].t1; - } - - if (j > 0) { - if (tokens[j - 1].t1 > tokens[j].t0) { - tokens[j].t0 = tokens[j - 1].t1; - tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); - } - } - - tokens[j].tt0 = tokens[j].t0; - tokens[j].tt1 = tokens[j].t1; - } - - // VAD - // expand or contract tokens based on voice activity - { - const int hw = WHISPER_SAMPLE_RATE/8; - - for (int j = 0; j < n; j++) { - if (tokens[j].id >= whisper_token_eot(ctx)) { - continue; - } - - const int64_t t0 = tokens[j].t0; - const int64_t t1 = tokens[j].t1; - - int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100)); - int s1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100)); - - const int ss0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100) - hw); - const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw); - - const int n = ss1 - ss0; - - float sum = 0.0f; - - for (int k = ss0; k < ss1; k++) { - sum += pcm_avg[k]; - } - - const float thold = 0.5*sum/n; - - { - int k = s0; - if (pcm_avg[k] > thold && j > 0) { - while (k > 0 && pcm_avg[k] > thold) { - k--; - } - tokens[j].t0 = (int64_t) 
(100*k/WHISPER_SAMPLE_RATE); - if (tokens[j].t0 < tokens[j - 1].t1) { - tokens[j].t0 = tokens[j - 1].t1; - } else { - s0 = k; - } - } else { - while (pcm_avg[k] < thold && k < s1) { - k++; - } - s0 = k; - tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE; - } - } - - { - int k = s1; - if (pcm_avg[k] > thold) { - while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) { - k++; - } - tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE; - if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) { - tokens[j].t1 = tokens[j + 1].t0; - } else { - s1 = k; - } - } else { - while (pcm_avg[k] < thold && k > s0) { - k--; - } - s1 = k; - tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE; - } - } - } - } - - // fixed token expand (optional) - { - const int t_expand = 0; - - for (int j = 0; j < n; j++) { - if (j > 0) { - tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); - } - if (j < n - 1) { - tokens[j].t1 = tokens[j].t1 + t_expand; - } - } - } - - // debug info - // TODO: toggle via parameter - for (int j = 0; j < n; ++j) { - const auto & token = tokens[j]; - const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? 
whisper_token_to_str(ctx, token.tid) : "[?]"; - printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, - tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str()); - - if (tokens[j].id >= whisper_token_eot(ctx)) { - continue; - } - - //printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id)); - - //fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n"; - } - - // TODO: become parameters - static const int line_wrap = 60; - static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - - if (!is_first) { + if (i > 0) { fout << ","; } // background text fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'"; - is_first = false; + bool is_first = true; for (int j = 0; j < n; ++j) { const auto & token = tokens[j]; @@ -654,17 +380,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f } ncnt += txt.size(); - - if (ncnt > line_wrap) { - if (k < j) { - txt_bg = "> "; - txt_fg = "> "; - txt_ul = "\\ \\ "; - ncnt = 0; - } else { - break; - } - } } ::replace_all(txt_bg, "'", "’"); @@ -673,8 +388,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f ::replace_all(txt_fg, "\"", "\\\""); } - // background text - fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'"; + if (is_first) { + // background text + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'"; + is_first = false; + } // foreground text fout << ",drawtext=fontfile='" << 
font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; @@ -815,6 +533,10 @@ int main(int argc, char ** argv) { wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; wparams.offset_ms = params.offset_t_ms; + wparams.token_timestamps = params.output_wts || params.max_len > 0; + wparams.thold_pt = params.word_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + // this callback is called on each new segment if (!wparams.print_realtime) { wparams.new_segment_callback = whisper_print_segment_callback; @@ -852,7 +574,7 @@ int main(int argc, char ** argv) { // output to WTS file if (params.output_wts) { const auto fname_wts = fname_inp + ".wts"; - output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32); + output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); } } } diff --git a/whisper.cpp b/whisper.cpp index b230d0c..02ab5cb 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -418,6 +418,12 @@ struct whisper_context { std::vector result_all; std::vector prompt_past; + + // [EXPERIMENTAL] token-level timestamps data + int64_t t_beg; + int64_t t_last; + whisper_token tid_last; + std::vector energy; // PCM signal energy }; // load the model from a ggml file @@ -431,7 +437,7 @@ struct whisper_context { // // see the convert-pt-to-ggml.py script for details // -bool whisper_model_load(const std::string & fname, whisper_context & wctx) { +static bool whisper_model_load(const std::string & fname, whisper_context & wctx) { fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str()); auto & model = wctx.model; @@ -1062,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { // - n_threads: number of threads to use // - mel_offset: offset in the mel spectrogram (i.e. 
audio offset) // -bool whisper_encode( +static bool whisper_encode( whisper_context & wctx, const int n_threads, const int mel_offset) { @@ -1448,7 +1454,7 @@ bool whisper_encode( // - n_tokens: number of tokens in the prompt // - n_past: number of past tokens to prefix the prompt with // -bool whisper_decode( +static bool whisper_decode( whisper_context & wctx, const int n_threads, const whisper_token * tokens, @@ -1811,10 +1817,12 @@ bool whisper_decode( } // the most basic sampling scheme - select the top token -whisper_token_data whisper_sample_best( +static whisper_token_data whisper_sample_best( const whisper_vocab & vocab, const float * probs) { - whisper_token_data result; + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, + }; int n_logits = vocab.id_to_token.size(); @@ -1887,7 +1895,7 @@ whisper_token_data whisper_sample_best( } // samples only from the timestamps tokens -whisper_vocab::id whisper_sample_timestamp( +static whisper_vocab::id whisper_sample_timestamp( const whisper_vocab & vocab, const float * probs) { int n_logits = vocab.id_to_token.size(); @@ -1939,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) { // naive Discrete Fourier Transform // input is real-valued // output is complex-valued -void dft(const std::vector & in, std::vector & out) { +static void dft(const std::vector & in, std::vector & out) { int N = in.size(); out.resize(N*2); @@ -1963,7 +1971,7 @@ void dft(const std::vector & in, std::vector & out) { // poor man's implementation - use something better // input is real-valued // output is complex-valued -void fft(const std::vector & in, std::vector & out) { +static void fft(const std::vector & in, std::vector & out) { out.resize(in.size()*2); int N = in.size(); @@ -2014,7 +2022,7 @@ void fft(const std::vector & in, std::vector & out) { } // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 -bool log_mel_spectrogram( +static bool log_mel_spectrogram( const 
float * samples, const int n_samples, const int sample_rate, @@ -2339,6 +2347,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.print_realtime =*/ false, /*.print_timestamps =*/ true, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.language =*/ "en", /*.greedy =*/ { @@ -2371,6 +2384,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.print_realtime =*/ false, /*.print_timestamps =*/ true, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.language =*/ "en", /*.greedy =*/ { @@ -2392,6 +2410,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str return result; } +// forward declarations +static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context * ctx, + int i_segment, + float thold_pt, + float thold_ptsum); + +// wrap the last segment to max_len characters +// returns the number of new segments +static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { + auto segment = ctx->result_all.back(); + + int res = 1; + int acc = 0; + + std::string text; + + for (int i = 0; i < (int) segment.tokens.size(); i++) { + const auto & token = segment.tokens[i]; + if (token.id >= whisper_token_eot(ctx)) { + continue; + } + + const auto txt = whisper_token_to_str(ctx, token.id); + + const int cur = strlen(txt); + + if (acc + cur > max_len && i > 0) { + // split here + ctx->result_all.back().text = std::move(text); + ctx->result_all.back().t1 = token.t0; + ctx->result_all.back().tokens.resize(i); + + ctx->result_all.push_back({}); + ctx->result_all.back().t0 = token.t0; + ctx->result_all.back().t1 = segment.t1; + + // add tokens [i, end] to the new segment + ctx->result_all.back().tokens.insert( + 
ctx->result_all.back().tokens.end(), + segment.tokens.begin() + i, + segment.tokens.end()); + + acc = 0; + text = ""; + + segment = ctx->result_all.back(); + i = -1; + + res++; + } else { + acc += cur; + text += txt; + } + } + + ctx->result_all.back().text = std::move(text); + + return res; +} + int whisper_full( struct whisper_context * ctx, struct whisper_full_params params, @@ -2408,6 +2488,13 @@ int whisper_full( return -1; } + if (params.token_timestamps) { + ctx->t_beg = 0; + ctx->t_last = 0; + ctx->tid_last = 0; + ctx->energy = get_signal_energy(samples, n_samples, 32); + } + const int seek_start = params.offset_ms/10; // if length of spectrogram is less than 1s (100 samples), then return @@ -2557,6 +2644,7 @@ int whisper_full( } } + // shrink down to result_len tokens_cur.resize(result_len); for (const auto & r : tokens_cur) { @@ -2595,8 +2683,19 @@ int whisper_full( for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(ctx, params.max_len); + } + } if (params.new_segment_callback) { - params.new_segment_callback(ctx, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } } text = ""; @@ -2625,8 +2724,19 @@ int whisper_full( for (int j = i0; j < (int) tokens_cur.size(); j++) { result_all.back().tokens.push_back(tokens_cur[j]); } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(ctx, params.max_len); + } + } if (params.new_segment_callback) { - params.new_segment_callback(ctx, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, n_new, 
params.new_segment_callback_user_data); } } } @@ -2760,7 +2870,7 @@ int whisper_full_parallel( // call the new_segment_callback for each segment if (params.new_segment_callback) { - params.new_segment_callback(ctx, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data); } } @@ -2836,3 +2946,304 @@ const char * whisper_print_system_info() { return s.c_str(); } + +// ================================================================================================= + +// +// Experimental stuff below +// +// Not sure if these should be part of the library at all, because the quality of the results is not +// guaranteed. Might get removed at some point unless a robust algorithm implementation is found +// + +// ================================================================================================= + +// +// token-level timestamps +// + +static int timestamp_to_sample(int64_t t, int n_samples) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); +} + +static int64_t sample_to_timestamp(int i_sample) { + return (100*i_sample)/WHISPER_SAMPLE_RATE; +} + +// a cost-function / heuristic that is high for text that takes longer to pronounce +// obviously, can be improved +static float voice_length(const std::string & text) { + float res = 0.0f; + + for (size_t i = 0; i < text.size(); ++i) { + if (text[i] == ' ') { + res += 0.01f; + } else if (text[i] == ',') { + res += 2.00f; + } else if (text[i] == '.') { + res += 3.00f; + } else if (text[i] == '!') { + res += 3.00f; + } else if (text[i] == '?') { + res += 3.00f; + } else if (text[i] >= '0' && text[i] <= '9') { + res += 3.00f; + } else { + res += 1.00f; + } + } + + return res; +} + +// average the fabs of the signal +static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) { + const int hw = n_samples_per_half_window; + + std::vector result(n_samples); + + for (int i = 0; 
i < n_samples; i++) { + float sum = 0; + for (int j = -hw; j <= hw; j++) { + if (i + j >= 0 && i + j < n_samples) { + sum += fabs(signal[i + j]); + } + } + result[i] = sum/(2*hw + 1); + } + + return result; +} + +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context * ctx, + int i_segment, + float thold_pt, + float thold_ptsum) { + auto & segment = ctx->result_all[i_segment]; + auto & tokens = segment.tokens; + + const int n_samples = ctx->energy.size(); + + if (n_samples == 0) { + fprintf(stderr, "%s: no signal data available\n", __func__); + return; + } + + const int64_t t0 = segment.t0; + const int64_t t1 = segment.t1; + + const int s0 = timestamp_to_sample(t0, n_samples); + const int s1 = timestamp_to_sample(t1, n_samples); + + const int n = tokens.size(); + + if (n == 0) { + return; + } + + if (n == 1) { + tokens[0].t0 = t0; + tokens[0].t1 = t1; + + return; + } + + auto & t_beg = ctx->t_beg; + auto & t_last = ctx->t_last; + auto & tid_last = ctx->tid_last; + + for (int j = 0; j < n; ++j) { + auto & token = tokens[j]; + + if (j == 0) { + if (token.id == whisper_token_beg(ctx)) { + tokens[j ].t0 = t0; + tokens[j ].t1 = t0; + tokens[j + 1].t0 = t0; + + t_beg = t0; + t_last = t0; + tid_last = whisper_token_beg(ctx); + } else { + tokens[j ].t0 = t_last; + } + } + + const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); + + tokens[j].id = token.id; + tokens[j].tid = token.tid; + tokens[j].p = token.p; + tokens[j].pt = token.pt; + tokens[j].ptsum = token.ptsum; + + tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id)); + + if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { + if (j > 0) { + tokens[j - 1].t1 = tt; + } + tokens[j].t0 = tt; + tid_last = token.tid; + } + } + + tokens[n - 2].t1 = t1; + tokens[n - 1].t0 = t1; + tokens[n - 1].t1 = t1; + + t_last = t1; + + // find intervals of tokens with unknown timestamps + // fill the timestamps by proportionally splitting the 
interval based on the token voice lengths + { + int p0 = 0; + int p1 = 0; + + while (true) { + while (p1 < n && tokens[p1].t1 < 0) { + p1++; + } + + if (p1 >= n) { + p1--; + } + + if (p1 > p0) { + double psum = 0.0; + for (int j = p0; j <= p1; j++) { + psum += tokens[j].vlen; + } + + //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + + const double dt = tokens[p1].t1 - tokens[p0].t0; + + // split the time proportionally to the voice length + for (int j = p0 + 1; j <= p1; j++) { + const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + + tokens[j - 1].t1 = ct; + tokens[j ].t0 = ct; + } + } + + p1++; + p0 = p1; + if (p1 >= n) { + break; + } + } + } + + // fix up (just in case) + for (int j = 0; j < n - 1; j++) { + if (tokens[j].t1 < 0) { + tokens[j + 1].t0 = tokens[j].t1; + } + + if (j > 0) { + if (tokens[j - 1].t1 > tokens[j].t0) { + tokens[j].t0 = tokens[j - 1].t1; + tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); + } + } + } + + // VAD + // expand or contract tokens based on voice activity + { + const int hw = WHISPER_SAMPLE_RATE/8; + + for (int j = 0; j < n; j++) { + if (tokens[j].id >= whisper_token_eot(ctx)) { + continue; + } + + int s0 = timestamp_to_sample(tokens[j].t0, n_samples); + int s1 = timestamp_to_sample(tokens[j].t1, n_samples); + + const int ss0 = std::max(s0 - hw, 0); + const int ss1 = std::min(s1 + hw, n_samples); + + const int ns = ss1 - ss0; + + float sum = 0.0f; + + for (int k = ss0; k < ss1; k++) { + sum += ctx->energy[k]; + } + + const float thold = 0.5*sum/ns; + + { + int k = s0; + if (ctx->energy[k] > thold && j > 0) { + while (k > 0 && ctx->energy[k] > thold) { + k--; + } + tokens[j].t0 = sample_to_timestamp(k); + if (tokens[j].t0 < tokens[j - 1].t1) { + tokens[j].t0 = tokens[j - 1].t1; + } else { + s0 = k; + } + } else { + while (ctx->energy[k] < thold && k < s1) { + k++; + } + s0 = k; + tokens[j].t0 = sample_to_timestamp(k); + } + } + + { + int k = s1; + if (ctx->energy[k] > thold) { + while (k < n_samples - 1 && 
ctx->energy[k] > thold) { + k++; + } + tokens[j].t1 = sample_to_timestamp(k); + if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) { + tokens[j].t1 = tokens[j + 1].t0; + } else { + s1 = k; + } + } else { + while (ctx->energy[k] < thold && k > s0) { + k--; + } + s1 = k; + tokens[j].t1 = sample_to_timestamp(k); + } + } + } + } + + // fixed token expand (optional) + //{ + // const int t_expand = 0; + + // for (int j = 0; j < n; j++) { + // if (j > 0) { + // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + // } + // if (j < n - 1) { + // tokens[j].t1 = tokens[j].t1 + t_expand; + // } + // } + //} + + // debug info + //for (int j = 0; j < n; ++j) { + // const auto & token = tokens[j]; + // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]"; + // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, + // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id)); + + // if (tokens[j].id >= whisper_token_eot(ctx)) { + // continue; + // } + //} +} diff --git a/whisper.h b/whisper.h index 5d7c40d..57ea5db 100644 --- a/whisper.h +++ b/whisper.h @@ -68,14 +68,21 @@ extern "C" { typedef int whisper_token; - struct whisper_token_data { + typedef struct whisper_token_data { whisper_token id; // token id whisper_token tid; // forced timestamp token id float p; // probability of the token float pt; // probability of the timestamp token float ptsum; // sum of probabilities of all timestamp tokens - }; + + // token-level timestamp data + // do not use if you haven't computed token-level timestamps + int64_t t0; // start time of the token + int64_t t1; // end time of the token + + float vlen; // voice length of the token + } whisper_token_data; // Allocates all memory needed for the model and loads the model from the given file. // Returns NULL on failure. 
@@ -129,7 +136,7 @@ extern "C" { // You can also implement your own sampling method using the whisper_get_probs() function. // whisper_sample_best() returns the token with the highest probability // whisper_sample_timestamp() returns the most probable timestamp token - WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx); + WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx); // Return the id of the specified language, returns -1 if not found @@ -172,7 +179,7 @@ extern "C" { // Text segment callback // Called on every newly generated text segment // Use the whisper_full_...() functions to obtain the text segments - typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data); + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); // NOTE(review): n_new is presumably the number of newly generated segments - TODO confirm against the caller struct whisper_full_params { enum whisper_sampling_strategy strategy; @@ -188,6 +195,12 @@ extern "C" { bool print_realtime; bool print_timestamps; + // [EXPERIMENTAL] token-level timestamps + bool token_timestamps; // enable token-level timestamps + float thold_pt; // timestamp token probability threshold (~0.01) + float thold_ptsum; // timestamp token sum probability threshold (~0.01) + int max_len; // max segment length in characters (the CLI help lists the -ml default as 0; presumably 0 means no limit - TODO confirm) + const char * language; struct { @@ -244,7 +257,7 @@ extern "C" { // Get token data for the specified token in the specified segment. // This contains probabilities, timestamps, etc. - WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); // Get the probability of the specified token in the specified segment. 
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);