stream : "-kc" now enables context keeping from previous segment (#90)

By default, the context keeping is disabled
pull/170/head
Georgi Gerganov 2022-11-22 18:20:05 +02:00
parent 63ae03b8e0
commit 385236d1d3
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 17 additions and 13 deletions

View File

@ -336,7 +336,7 @@ int main(int argc, char ** argv) {
wparams.print_realtime = false; wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps; wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate; wparams.translate = params.translate;
wparams.no_context = params.no_context; wparams.no_context = true;
wparams.single_segment = true; wparams.single_segment = true;
wparams.max_tokens = params.max_tokens; wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str(); wparams.language = params.language.c_str();
@ -345,9 +345,9 @@ int main(int argc, char ** argv) {
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
wparams.prompt_tokens = prompt_tokens.data(); wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.size(); wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 6; return 6;
@ -399,12 +399,15 @@ int main(int argc, char ** argv) {
pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end()); pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end());
// Add tokens of the last full length segment as the prompt // Add tokens of the last full length segment as the prompt
prompt_tokens.clear(); if (!params.no_context) {
const int n_segments = whisper_full_n_segments(ctx); prompt_tokens.clear();
for (int i = 0; i < n_segments; ++i) {
const int token_count = whisper_full_n_tokens(ctx, i); const int n_segments = whisper_full_n_segments(ctx);
for (int j = 0; j < token_count; ++j) { for (int i = 0; i < n_segments; ++i) {
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j)); const int token_count = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < token_count; ++j) {
prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j));
}
} }
} }
} }

View File

@ -2590,9 +2590,9 @@ int whisper_full(
prompt_past.clear(); prompt_past.clear();
} }
// Prepend the prompt tokens to the prompt_past // prepend the prompt tokens to the prompt_past
if (params.prompt_tokens && params.prompt_n_tokens > 0) { if (params.prompt_tokens && params.prompt_n_tokens > 0) {
// Parse tokens from the pointer (it points to an std::vector) // parse tokens from the pointer
for (int i = 0; i < params.prompt_n_tokens; i++) { for (int i = 0; i < params.prompt_n_tokens; i++) {
prompt_past.push_back(params.prompt_tokens[i]); prompt_past.push_back(params.prompt_tokens[i]);
} }

View File

@ -208,7 +208,8 @@ extern "C" {
bool speed_up; // speed-up the audio by 2x using Phase Vocoder bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default) int audio_ctx; // overwrite the audio context size (0 = use default)
// std::vector<whisper_token>: tokens to provide the whisper model as initial prompt // tokens to provide the whisper model as initial prompt
// these are prepended to any existing text context from a previous call
const whisper_token * prompt_tokens; const whisper_token * prompt_tokens;
int prompt_n_tokens; int prompt_n_tokens;