diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 32f93d6..9efc83c 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -336,7 +336,7 @@ int main(int argc, char ** argv) { wparams.print_realtime = false; wparams.print_timestamps = !params.no_timestamps; wparams.translate = params.translate; - wparams.no_context = params.no_context; + wparams.no_context = true; wparams.single_segment = true; wparams.max_tokens = params.max_tokens; wparams.language = params.language.c_str(); @@ -345,9 +345,9 @@ int main(int argc, char ** argv) { wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; - wparams.prompt_tokens = prompt_tokens.data(); - wparams.prompt_n_tokens = prompt_tokens.size(); - + wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); + wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size(); + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 6; @@ -399,12 +399,15 @@ int main(int argc, char ** argv) { pcmf32_old = std::vector(pcmf32.end() - n_samples_keep, pcmf32.end()); // Add tokens of the last full length segment as the prompt - prompt_tokens.clear(); - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const int token_count = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < token_count; ++j) { - prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j)); + if (!params.no_context) { + prompt_tokens.clear(); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const int token_count = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < token_count; ++j) { + prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j)); + } } } } diff --git a/whisper.cpp b/whisper.cpp index 28c5d26..6c2e0e0 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2590,9 +2590,9 @@ int whisper_full( prompt_past.clear(); } - // Prepend the prompt tokens to the prompt_past + // prepend the prompt tokens to the prompt_past if (params.prompt_tokens && params.prompt_n_tokens > 0) { - // Parse tokens from the pointer (it points to an std::vector) + // parse tokens from the pointer for (int i = 0; i < params.prompt_n_tokens; i++) { prompt_past.push_back(params.prompt_tokens[i]); } diff --git a/whisper.h b/whisper.h index 1b2a042..58a8872 100644 --- a/whisper.h +++ b/whisper.h @@ -208,7 +208,8 @@ extern "C" { bool speed_up; // speed-up the audio by 2x using Phase Vocoder int audio_ctx; // overwrite the audio context size (0 = use default) - // std::vector: tokens to provide the whisper model as initial prompt + // tokens to provide the whisper model as initial prompt + // these are prepended to any existing text context from a previous call const whisper_token * prompt_tokens; int prompt_n_tokens;