diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 040ba9e..4ff93d3 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -40,6 +40,7 @@ struct whisper_params { int32_t step_ms = 3000; int32_t length_ms = 10000; int32_t capture_id = -1; + int32_t audio_ctx = 0; bool speed_up = false; bool verbose = false; @@ -69,6 +70,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.length_ms = std::stoi(argv[++i]); } else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); + } else if (arg == "-ac" || arg == "--audio_ctx") { + params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-v" || arg == "--verbose") { @@ -116,6 +119,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms); fprintf(stderr, " --length N audio length in milliseconds (default: %d)\n", params.length_ms); fprintf(stderr, " -c ID, --capture ID capture device ID (default: -1)\n"); + fprintf(stderr, " -ac N, --audio_ctx N audio context size (default: %d, 0 - all)\n", params.audio_ctx); fprintf(stderr, " -su, --speed-up speed up audio by factor of 2 (faster processing, reduced accuracy, default: %s)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); @@ -322,7 +326,6 @@ int main(int argc, char ** argv) { { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - wparams.max_tokens = 32; wparams.print_progress = false; wparams.print_special_tokens = params.print_special_tokens; wparams.print_realtime = false; @@ -330,9 +333,11 @@ int main(int argc, char ** argv) { wparams.translate = params.translate; wparams.no_context = params.no_context; wparams.single_segment = true; + wparams.max_tokens = 32; wparams.language = params.language.c_str(); wparams.n_threads = params.n_threads; + wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { diff --git a/whisper.cpp b/whisper.cpp index 48f93eb..d35b90f 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -424,6 +424,9 @@ struct whisper_context { int64_t t_last; whisper_token tid_last; std::vector energy; // PCM signal energy + + // [EXPERIMENTAL] speed-up techniques + int32_t exp_n_audio_ctx; // 0 - use default }; // load the model from a ggml file @@ -974,9 +977,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); - - //memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); - //memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); } const size_t memory_size = @@ -1079,7 +1079,7 @@ static bool whisper_encode( const auto & mel_inp = wctx.mel; const auto & hparams = model.hparams; - const int n_ctx = WHISPER_EXPERIMENT_AUDIO_CTX; + const int n_ctx = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; const int n_state = hparams.n_audio_state; const int n_head = hparams.n_audio_head; const int n_layer = hparams.n_audio_layer; @@ -1133,6 +1133,8 @@ static bool whisper_encode( cur = ggml_gelu(ctx0, cur); } + // =================================================================== + // NOTE: experimenting with partial evaluation of the encoder (ignore) //static int iter = -1; //const int n_iter = 1500/n_ctx; @@ -1151,6 +1153,10 @@ static bool whisper_encode( struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur)); + // =================================================================== + + // original: + //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); struct ggml_tensor * inpL = cur; @@ -1494,8 +1500,7 @@ static bool whisper_decode( const int n_layer = hparams.n_text_layer; const int N = n_tokens; - //const int M = hparams.n_audio_ctx; - const int M = WHISPER_EXPERIMENT_AUDIO_CTX; + const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx; struct ggml_init_params params = { .mem_size = wctx.buf_compute.size(), @@ -2405,6 +2410,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_tokens =*/ 0, /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, /*.language =*/ "en", @@ -2447,6 +2453,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_tokens =*/ 0, /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, /*.language =*/ "en", @@ -2577,6 +2584,9 @@ int whisper_full( prompt_past.clear(); } + // overwrite audio_ctx + ctx->exp_n_audio_ctx = params.audio_ctx; + // these tokens determine the task that will be performed std::vector prompt_init = { whisper_token_sot(ctx) }; if (whisper_is_multilingual(ctx)) { diff --git a/whisper.h b/whisper.h index 0211995..88cc711 100644 --- a/whisper.h +++ b/whisper.h @@ -24,8 +24,6 @@ #define WHISPER_HOP_LENGTH 160 #define WHISPER_CHUNK_SIZE 30 -#define WHISPER_EXPERIMENT_AUDIO_CTX 512 - #ifdef __cplusplus extern "C" { #endif @@ -207,7 +205,8 @@ extern "C" { int max_tokens; // max tokens per segment (0 = no limit) // [EXPERIMENTAL] speed-up techniques - bool speed_up; // speed-up the audio by 2x using Phase Vocoder + bool speed_up; // speed-up the audio by 2x using Phase Vocoder + int audio_ctx; // overwrite the audio context size (0 = use default) const char * language;