From 741abb162ce8207eca2562e59c03efefb21c0122 Mon Sep 17 00:00:00 2001
From: denersc
Date: Wed, 20 Mar 2024 13:25:26 -0300
Subject: [PATCH] whisper : token-level timestamps with DTW (#1485)

* whisper.cpp: impl dtw algo
* WIP: producing and placing DTW timestamps on tokens
* Fix compile and assertion errors. Attempt to DTW timestamp with single_segment=false.
* Fix mistake causing incorrect alignment of dtw timestamps
* implement N_TOP_MOST and CUSTOM alignment heads setting
* whisper: fix typo on alignment heads enum
* Fix issues related to changes in whisper.cpp
* Fixed excessive memory use when using DTW timestamps. Other minor fixes to DTW timestamping function
* decoder: save cross QKs only if requested
* Calling median filter with ggml_map_custom1
* Reimpl aheads n_top_most and custom. Sanity checks on chosen aheads
* Copying cross QKs from decoder backend correctly
* dtw: cleanup
* Fix incorrect n_frames passed to dtw when near end of audio
* Fix aheads_masks_init for backend != CPU
* whisper : minor style
* main : add dtw (wip)
* whisper: fix invalid memory access in aheads_masks_init
* main : add dtw (cont)
* whisper : minor

---------

Co-authored-by: Georgi Gerganov
---
 examples/main/main.cpp |  51 +++-
 whisper.cpp            | 581 ++++++++++++++++++++++++++++++++++++++++-
 whisper.h              |  41 +++
 3 files changed, 652 insertions(+), 21 deletions(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 50185cc..caa800b 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -26,17 +26,17 @@ void replace_all(std::string & s, const std::string & search, const std::string
 
 // command-line parameters
 struct whisper_params {
-    int32_t n_threads    = std::min(4, (int32_t) std::thread::hardware_concurrency());
-    int32_t n_processors = 1;
-    int32_t offset_t_ms  = 0;
-    int32_t offset_n     = 0;
-    int32_t duration_ms  = 0;
-    int32_t progress_step = 5;
-    int32_t max_context  = -1;
-    int32_t max_len      = 0;
-    int32_t best_of      = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
-    int32_t beam_size    = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
-    int32_t audio_ctx    = 0;
+    int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    int32_t n_processors  = 1;
+    int32_t offset_t_ms   = 0;
+    int32_t offset_n      = 0;
+    int32_t duration_ms   = 0;
+    int32_t progress_step = 5;
+    int32_t max_context   = -1;
+    int32_t max_len       = 0;
+    int32_t best_of       = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
+    int32_t beam_size     = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
+    int32_t audio_ctx     = 0;
 
     float word_thold    = 0.01f;
     float entropy_thold = 2.40f;
@@ -76,6 +76,8 @@ struct whisper_params {
 
     std::string openvino_encode_device = "CPU";
 
+    std::string dtw = "";
+
     std::vector<std::string> fname_inp = {};
     std::vector<std::string> fname_out = {};
 };
@@ -149,6 +151,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-m"    || arg == "--model")       { params.model                  = argv[++i]; }
         else if (arg == "-f"    || arg == "--file")        { params.fname_inp.emplace_back(argv[++i]); }
         else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
+        else if (arg == "-dtw"  || arg == "--dtw")         { params.dtw                    = argv[++i]; }
         else if (arg == "-ls"   || arg == "--log-score")   { params.log_score              = true; }
         else if (arg == "-ng"   || arg == "--no-gpu")      { params.use_gpu                = false; }
         else {
@@ -208,6 +211,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
" -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); fprintf(stderr, "\n"); @@ -649,7 +653,8 @@ bool output_json( times_o(token.t0, token.t1, false); } value_i("id", token.id, false); - value_f("p", token.p, true); + value_f("p", token.p, false); + value_f("t_dtw", token.t_dtw, true); end_obj(j == (n - 1)); } end_arr(!params.diarize && !params.tinydiarize); @@ -889,6 +894,28 @@ int main(int argc, char ** argv) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + if (!params.dtw.empty()) { + cparams.dtw_token_timestamps = true; + cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE; + + if (params.dtw == "tiny") cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY; + if (params.dtw == "tiny.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN; + if (params.dtw == "base") cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE; + if (params.dtw == "base.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN; + if (params.dtw == "small") cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL; + if (params.dtw == "small.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN; + if (params.dtw == "medium") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM; + if (params.dtw == "medium.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN; + if (params.dtw == "large.v1") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1; + if (params.dtw == "large.v2") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2; + if (params.dtw == "large.v3") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3; + + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) { + fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str()); + return 3; + } + } + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); if (ctx == nullptr) { diff --git a/whisper.cpp b/whisper.cpp index f601197..ed4c818 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -351,6 +351,35 @@ static const std::map> g_lang = { { "yue", { 99, "cantonese", } }, }; +// [EXPERIMENTAL] Token-level timestamps with DTW +static const whisper_ahead g_aheads_tiny_en[] = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} }; +static const whisper_ahead g_aheads_tiny[] = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} }; +static const whisper_ahead g_aheads_base_en[] = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} }; +static const whisper_ahead g_aheads_base[] = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} }; +static const whisper_ahead g_aheads_small_en[] = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} }; +static const whisper_ahead g_aheads_small[] = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} }; +static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, 
diff --git a/whisper.cpp b/whisper.cpp
index f601197..ed4c818 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -351,6 +351,35 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
     { "yue", { 99, "cantonese", } },
 };
 
+// [EXPERIMENTAL] Token-level timestamps with DTW
+static const whisper_ahead g_aheads_tiny_en[]   = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} };
+static const whisper_ahead g_aheads_tiny[]      = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} };
+static const whisper_ahead g_aheads_base_en[]   = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} };
+static const whisper_ahead g_aheads_base[]      = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} };
+static const whisper_ahead g_aheads_small_en[]  = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} };
+static const whisper_ahead g_aheads_small[]     = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} };
+static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} };
+static const whisper_ahead g_aheads_medium[]    = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} };
+static const whisper_ahead g_aheads_large_v1[]  = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
+static const whisper_ahead g_aheads_large_v2[]  = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
+static const whisper_ahead g_aheads_large_v3[]  = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
+
+static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
+    { WHISPER_AHEADS_TINY_EN,   {  8, g_aheads_tiny_en   } },
+    { WHISPER_AHEADS_TINY,      {  6, g_aheads_tiny      } },
+    { WHISPER_AHEADS_BASE_EN,   {  5, g_aheads_base_en   } },
+    { WHISPER_AHEADS_BASE,      {  8, g_aheads_base      } },
+    { WHISPER_AHEADS_SMALL_EN,  { 19, g_aheads_small_en  } },
+    { WHISPER_AHEADS_SMALL,     { 10, g_aheads_small     } },
+    { WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } },
+    { WHISPER_AHEADS_MEDIUM,    {  6, g_aheads_medium    } },
+    { WHISPER_AHEADS_LARGE_V1,  {  9, g_aheads_large_v1  } },
+    { WHISPER_AHEADS_LARGE_V2,  { 23, g_aheads_large_v2  } },
+    { WHISPER_AHEADS_LARGE_V3,  { 10, g_aheads_large_v3  } },
+};
+
+static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
+
 struct whisper_mel {
     int n_len;
     int n_len_org;
@@ -750,6 +779,13 @@ struct whisper_decoder {
     mutable std::mt19937 rng; // used for sampling at t > 0.0
 };
 
+// [EXPERIMENTAL] Token-level timestamps with DTW
+struct whisper_aheads_masks {
+    std::vector<struct ggml_tensor *> m; // One mask per text layer.
+    struct ggml_context * ctx = nullptr;
+    ggml_backend_buffer_t buffer = nullptr;
+};
+
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
@@ -823,6 +859,11 @@ struct whisper_state {
 
     std::vector<float> energy; // PCM signal energy
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    whisper_aheads_masks aheads_masks;
+    ggml_tensor * aheads_cross_QKs = nullptr;
+    std::vector<float> aheads_cross_QKs_data;
+
     // [EXPERIMENTAL] speed-up techniques
     int32_t exp_n_audio_ctx = 0; // 0 - use default
 };
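Note, not part of the patch: the tables above hard-code the alignment heads published for the stock OpenAI models. For a fine-tuned model, the same machinery is reachable through WHISPER_AHEADS_CUSTOM and the whisper_aheads struct declared in whisper.h at the end of this patch. A sketch with made-up head coordinates:

    // hypothetical alignment heads for a fine-tuned model: {n_text_layer, n_head}
    static const whisper_ahead my_aheads_arr[] = { {4, 2}, {5, 0}, {5, 7} };

    struct whisper_context_params cparams = whisper_context_default_params();
    cparams.dtw_token_timestamps = true;
    cparams.dtw_aheads_preset    = WHISPER_AHEADS_CUSTOM;
    cparams.dtw_aheads.n_heads   = 3;
    cparams.dtw_aheads.heads     = my_aheads_arr; // read during context initialization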
@@ -1027,6 +1068,132 @@ static void whisper_kv_cache_seq_cp(
     }
 }
 
+// [EXPERIMENTAL] Token-level timestamps with DTW
+static bool aheads_masks_init(
+        const whisper_context_params & cparams,
+               const whisper_hparams & hparams,
+         struct whisper_aheads_masks & aheads_masks,
+                       ggml_backend_t   backend) {
+
+    const int32_t n_text_layer = hparams.n_text_layer;
+    const int32_t n_head       = hparams.n_text_head;
+
+    // Sanity checks
+    if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+        WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != WHISPER_AHEADS_NONE\n", __func__);
+        return false;
+    } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
+        if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) {
+            WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model\n", __func__, 1, n_text_layer);
+            return false;
+        }
+    } else {
+        const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
+        if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) {
+            if (aheads.n_heads == 0) {
+                WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0\n", __func__);
+                return false;
+            }
+            if (aheads.heads == NULL) {
+                WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset\n", __func__);
+                return false;
+            }
+        }
+        for (size_t i = 0; i < aheads.n_heads; ++i) {
+            if (aheads.heads[i].n_text_layer >= n_text_layer) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers\n", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer);
+                return false;
+            }
+            if (aheads.heads[i].n_text_layer < 0) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0\n", __func__);
+                return false;
+            }
+            if (aheads.heads[i].n_head >= n_head) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads\n", __func__, aheads.heads[i].n_head + 1, n_head);
+                return false;
+            }
+            if (aheads.heads[i].n_head < 0) {
+                WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0\n", __func__);
+                return false;
+            }
+        }
+    }
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ static_cast<size_t>(n_text_layer)*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
+    };
+
+    aheads_masks.ctx = ggml_init(params);
+
+    if (!aheads_masks.ctx) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__);
+        return false;
+    }
+
+    for (int64_t il = 0; il < n_text_layer; ++il) {
+        auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
+        if (!aheads.empty()) {
+            aheads_masks.m.push_back(ggml_new_tensor_2d(aheads_masks.ctx, GGML_TYPE_F32, n_head, aheads.size()));
+        } else {
+            aheads_masks.m.push_back(nullptr);
+        }
+    }
+
+    aheads_masks.buffer = ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend);
+    if (!aheads_masks.buffer) {
+        WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__);
+        return false;
+    }
+
+    // Set data on mask tensors
+    // Since this must be backend agnostic, we get tensor data with
+    // ggml_backend_tensor_get, copy our desired values and send it back
+    // to backend with ggml_backend_tensor_set
+    std::vector<uint8_t> mask_data;
+    for (int64_t il = 0; il < n_text_layer; ++il) {
+        if (aheads_masks.m[il] != nullptr) {
+            auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
+
+            size_t data_size = aheads_masks.m[il]->ne[0]*aheads_masks.m[il]->ne[1]*sizeof(float);
+            mask_data.resize(data_size);
+            ggml_backend_tensor_get(aheads_masks.m[il], mask_data.data(), 0, data_size);
+            memset(mask_data.data(), 0, data_size);
+
+            for (size_t ih = 0; ih < aheads.size(); ++ih) {
+                size_t pos = (aheads[ih] + ih*aheads_masks.m[il]->ne[0])*sizeof(float);
+                float v = 1.0f;
+                memcpy(mask_data.data() + pos, &v, sizeof(float));
+            }
+
+            ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size);
+        }
+    }
+
+    if (aheads_masks.m.empty()) {
+        WHISPER_LOG_ERROR("%s: no alignment heads set\n", __func__);
+        return false;
+    }
+
+    return true;
+}
+
+static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) {
+    ggml_free(aheads_masks.ctx);
+    ggml_backend_buffer_free(aheads_masks.buffer);
+    aheads_masks.ctx = nullptr;
+}
+
+static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
+    size_t size = 0;
+    for (size_t i = 0; i < aheads_masks.m.size(); ++i) {
+        if (aheads_masks.m[i] != nullptr)
+            size += ggml_nbytes(aheads_masks.m[i]);
+    }
+    return size;
+}
+
 static ggml_backend_t whisper_backend_init(const whisper_context_params & params) {
     ggml_backend_t backend_gpu = NULL;
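Note, not part of the patch: the inner loop above writes one 1.0f per selected head into an F32 mask with ne[0] = n_head columns and one row per selected head; pos is the row-major element index times sizeof(float) because mask_data is a raw byte buffer. The same arithmetic on a plain array:

    #include <cstdio>
    #include <cstring>

    int main() {
        enum { N_HEAD = 8, N_SEL = 2 };       // ne[0] = N_HEAD, ne[1] = N_SEL
        const int selected[N_SEL] = { 2, 5 }; // heads chosen for one layer

        float mask[N_SEL*N_HEAD];
        memset(mask, 0, sizeof(mask));
        for (int ih = 0; ih < N_SEL; ++ih) {
            // element index; the patch scales this by sizeof(float) for its byte buffer
            mask[selected[ih] + ih*N_HEAD] = 1.0f;
        }

        for (int r = 0; r < N_SEL; ++r) {
            for (int c = 0; c < N_HEAD; ++c) {
                printf("%.0f ", mask[c + r*N_HEAD]);
            }
            printf("\n"); // prints: 0 0 1 0 0 0 0 0 / 0 0 0 0 0 1 0 0
        }
        return 0;
    }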
@@ -2105,6 +2272,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
          whisper_context & wctx,
          whisper_state   & wstate,
      const whisper_batch & batch,
+                    bool   save_alignment_heads_QKs,
                     bool   worst_case) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
@@ -2158,6 +2326,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     struct ggml_tensor * inpL = cur;
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    struct ggml_tensor * aheads_cross_QKs = nullptr;
+
     for (int il = 0; il < n_layer; ++il) {
         const auto & layer = model.layers_decoder[il];
 
@@ -2337,6 +2508,24 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            if (wctx.params.dtw_token_timestamps) {
+                if (wstate.aheads_masks.m[il] != nullptr) {
+                    struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
+                    aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                    aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                    aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
+                    aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
+                    aheads_KQs = ggml_cont(ctx0, aheads_KQs);
+                    aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
+                    if (aheads_cross_QKs == NULL) {
+                        aheads_cross_QKs = aheads_KQs;
+                    } else {
+                        aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
+                    }
+                }
+            }
+
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -2422,6 +2611,16 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) {
+        aheads_cross_QKs = ggml_transpose(ctx0, aheads_cross_QKs);
+        aheads_cross_QKs = ggml_cont(ctx0, aheads_cross_QKs);
+        if (save_alignment_heads_QKs) {
+            ggml_build_forward_expand(gf, aheads_cross_QKs);
+            wstate.aheads_cross_QKs = aheads_cross_QKs;
+        }
+    }
+
     ggml_build_forward_expand(gf, logits);
 
     ggml_free(ctx0);
@@ -2444,6 +2643,7 @@ static bool whisper_decode_internal(
           whisper_state & wstate,
     const whisper_batch & batch,
               const int   n_threads,
+                   bool   save_alignment_heads_QKs,
     ggml_abort_callback   abort_callback,
                    void * abort_callback_data) {
     const int64_t t_start_us = ggml_time_us();
@@ -2475,7 +2675,7 @@ static bool whisper_decode_internal(
     {
         auto & alloc = wstate.alloc_decode.alloc;
 
-        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, false);
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
 
         if (!ggml_gallocr_alloc_graph(alloc, gf)) {
             // should never happen as we pre-allocate the memory
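Note, not part of the patch: the reshape/transpose/mul_mat sequence above is plain head selection. Once KQ_soft_max is flattened so the head index is the inner dimension, multiplying by the layer's one-hot mask keeps exactly the rows belonging to the chosen alignment heads. A toy-sized sketch of the same product:

    #include <cstdio>

    int main() {
        const int n_head = 4, n_cols = 3; // n_cols stands in for n_audio*n_tokens
        // stacked per-head attention values: head h lives in row h
        float kq[4][3] = { {1,1,1}, {2,2,2}, {3,3,3}, {4,4,4} };
        // one-hot mask selecting heads 1 and 3 (one row per selected head)
        float mask[2][4] = { {0,1,0,0}, {0,0,0,1} };

        float out[2][3] = {};
        for (int r = 0; r < 2; ++r)
            for (int c = 0; c < n_cols; ++c)
                for (int h = 0; h < n_head; ++h)
                    out[r][c] += mask[r][h]*kq[h][c];

        for (int r = 0; r < 2; ++r)
            printf("%g %g %g\n", out[r][0], out[r][1], out[r][2]); // rows of heads 1 and 3
        return 0;
    }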
@@ -3003,6 +3203,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }
 
+    // [EXPERIMENTAL] Token-level timestamps with DTW
+    if (ctx->params.dtw_token_timestamps) {
+        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
+            WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
+            whisper_free_state(state);
+            return nullptr;
+        }
+        const size_t memory_size = aheads_masks_nbytes(state->aheads_masks);
+        WHISPER_LOG_INFO("%s: alignment heads masks size = %zu B\n", __func__, memory_size);
+    }
+
 #ifdef WHISPER_USE_COREML
     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
 
@@ -3095,7 +3306,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
             whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);
 
-            return whisper_build_graph_decoder(*ctx, *state, state->batch, true);
+            return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true);
         });
 
     if (!ok) {
@@ -3161,8 +3372,17 @@ int whisper_ctx_init_openvino_encoder(
 
 struct whisper_context_params whisper_context_default_params() {
     struct whisper_context_params result = {
-        /*.use_gpu    =*/ true,
-        /*.gpu_device =*/ 0,
+        /*.use_gpu              =*/ true,
+        /*.gpu_device           =*/ 0,
+
+        /*.dtw_token_timestamps =*/ false,
+        /*.dtw_aheads_preset    =*/ WHISPER_AHEADS_NONE,
+        /*.dtw_n_top            =*/ -1,
+        /*.dtw_aheads           =*/ {
+            /*.n_heads          =*/ 0,
+            /*.heads            =*/ NULL,
+        },
+        /*.dtw_mem_size         =*/ 1024*1024*128,
     };
     return result;
 }
@@ -3357,6 +3577,9 @@ void whisper_free_state(struct whisper_state * state) {
 
         ggml_backend_free(state->backend);
 
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        aheads_masks_free(state->aheads_masks);
+
         delete state;
     }
 }
@@ -3476,7 +3699,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state
 
     whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);
 
-    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, nullptr, nullptr)) {
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -4411,6 +4634,17 @@ static inline bool should_split_on_word(const char * txt, bool split_on_word) {
     return txt[0] == ' ';
 }
 
+static void whisper_exp_compute_token_level_timestamps_dtw(
+        struct whisper_context * ctx,
+          struct whisper_state * state,
+      struct whisper_full_params params,
+                            int   i_segment,
+                         size_t   n_segments,
+                            int   seek,
+                            int   n_frames,
+                            int   medfilt_width,
+                            int   n_threads);
+
 // wrap the last segment to max_len characters
 // returns the number of new segments
 static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
@@ -4779,7 +5013,7 @@ static whisper_token_data whisper_sample_token(
     const whisper_decoder & decoder,
                      bool   best) {
     whisper_token_data result = {
-        0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
+        0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
     };
 
     const auto & vocab = ctx.vocab;
@@ -4897,7 +5131,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
             const auto id = dist(decoder.rng);
             //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
 
-            result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
+            result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, });
 
             if (result[i].id >= vocab.token_beg) {
                 result[i].tid = result[i].id;
@@ -5259,7 +5493,7 @@ int whisper_full_with_state(
 
                 whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
 
-                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -7;
                 }
@@ -5559,7 +5793,7 @@ int whisper_full_with_state(
 
                 assert(batch.n_tokens > 0);
 
-                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
                     WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -8;
                 }
@@ -5682,6 +5916,9 @@ int whisper_full_with_state(
 
             const auto & tokens_cur = best_decoder.sequence.tokens;
 
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            const auto n_segments_before = state->result_all.size();
+
             //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
 
             // update prompt_past
@@ -5799,6 +6036,17 @@ int whisper_full_with_state(
                 }
             }
 
+            // FIXME: will timestamp offsets be correct?
+            // [EXPERIMENTAL] Token-level timestamps with DTW
+            {
+                const auto n_segments = state->result_all.size() - n_segments_before;
+                if (ctx->params.dtw_token_timestamps && n_segments) {
+                    const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
+                    whisper_exp_compute_token_level_timestamps_dtw(
+                            ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
+                }
+            }
+
             // update audio window
             seek += seek_delta;
 
@@ -6601,6 +6849,321 @@ static void whisper_exp_compute_token_level_timestamps(
     //}
 }
 
+//
+// token level timestamps - dtw version
+//
+
+// n_text_layer -> total text layers on model
+// n_head -> total heads per text layer on model
+static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) {
+    std::vector<uint32_t> ret;
+    if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+        return ret;
+    } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
+        if (il >= n_text_layer - cparams.dtw_n_top) {
+            for (int32_t i = 0; i < n_head; ++i) {
+                ret.push_back(i);
+            }
+        }
+    } else {
+        const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
+        for (size_t i = 0; i < aheads.n_heads; ++i) {
+            if (aheads.heads[i].n_text_layer == il) {
+                ret.push_back(aheads.heads[i].n_head);
+            }
+        }
+    }
+    return ret;
+}
+
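Note, not part of the patch: the function below fills cost[i][j] = x[i-1][j-1] + min(cost[i-1][j-1], cost[i-1][j], cost[i][j-1]) and then walks the trace matrix back from (N, M). The same recurrence on plain vectors, stripped of ggml and of the final clip/transpose:

    #include <cstdio>
    #include <vector>
    #include <cmath>

    // minimal DTW over an N x M cost matrix x; prints the alignment path
    // from (N-1, M-1) back to (0, 0)
    static void dtw_demo(const std::vector<std::vector<float>> & x) {
        const int N = (int) x.size(), M = (int) x[0].size();
        std::vector<std::vector<float>> cost(N + 1, std::vector<float>(M + 1, INFINITY));
        std::vector<std::vector<int>>   trace(N + 1, std::vector<int>(M + 1, -1));
        cost[0][0] = 0.0f;

        for (int j = 1; j <= M; ++j) {
            for (int i = 1; i <= N; ++i) {
                const float c0 = cost[i - 1][j - 1], c1 = cost[i - 1][j], c2 = cost[i][j - 1];
                int t = 2; float c = c2;
                if (c0 < c1 && c0 < c2) { c = c0; t = 0; }
                else if (c1 < c0 && c1 < c2) { c = c1; t = 1; }
                cost[i][j]  = x[i - 1][j - 1] + c;
                trace[i][j] = t;
            }
        }

        // backtrace; edges behave as if trace[0][*] = 2 and trace[*][0] = 1
        int i = N, j = M;
        while (i > 0 || j > 0) {
            printf("(%d, %d)\n", i - 1, j - 1);
            const int t = (i == 0) ? 2 : (j == 0) ? 1 : trace[i][j];
            if (t == 0) { --i; --j; } else if (t == 1) { --i; } else { --j; }
        }
    }

    int main() {
        dtw_demo({ {0.1f, 0.9f, 0.9f},
                   {0.9f, 0.1f, 0.9f},
                   {0.9f, 0.9f, 0.1f} }); // follows the diagonal
        return 0;
    }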
+// dtw + backtrace to return found path
+// based on
+// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
+static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
+    WHISPER_ASSERT(ggml_n_dims(x) == 2);
+
+    int64_t N = x->ne[0];
+    int64_t M = x->ne[1];
+    struct ggml_tensor * cost  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
+    struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);
+
+    cost  = ggml_set_f32(cost, INFINITY);
+    trace = ggml_set_f32(trace, -1);
+    ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
+
+    // dtw
+    // supposedly can be optimized by computing diagonals in parallel ?
+    // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
+    for (int64_t j = 1; j < M + 1; ++j) {
+        for (int64_t i = 1; i < N + 1; ++i) {
+            float c0 = ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0);
+            float c1 = ggml_get_f32_nd(cost, i - 1, j, 0, 0);
+            float c2 = ggml_get_f32_nd(cost, i, j - 1, 0, 0);
+
+            float c;
+            int32_t t;
+            if (c0 < c1 && c0 < c2) {
+                c = c0;
+                t = 0;
+            } else if (c1 < c0 && c1 < c2) {
+                c = c1;
+                t = 1;
+            } else {
+                c = c2;
+                t = 2;
+            }
+
+            c = ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
+            ggml_set_f32_nd(cost, i, j, 0, 0, c);
+            ggml_set_i32_nd(trace, i, j, 0, 0, t);
+        }
+    }
+
+    // Backtrace
+    const int64_t BT_MAX_ROWS = N + M - 1;
+    struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
+    // trace[0, :] = 2;
+    for (int64_t i = 0; i < M + 1; ++i)
+        ggml_set_i32_nd(trace, 0, i, 0, 0, 2);
+    //trace[:, 0] = 1;
+    for (int64_t i = 0; i < N + 1; ++i)
+        ggml_set_i32_nd(trace, i, 0, 0, 0, 1);
+    int bt_row_idx = BT_MAX_ROWS - 1;
+    int64_t i = N;
+    int64_t j = M;
+    while (i > 0 || j > 0) {
+        ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
+        ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
+        --bt_row_idx;
+
+        int32_t t = ggml_get_i32_nd(trace, i, j, 0, 0);
+        if (t == 0) {
+            --i;
+            --j;
+        } else if (t == 1) {
+            --i;
+        } else if (t == 2) {
+            --j;
+        } else {
+            WHISPER_ASSERT(0);
+        }
+    }
+
+    // FIXME: manual clip/transpose might not be the most efficient way? (e.g. use ggml funcs)
+    // Clip + transpose
+    // This might not be entirely necessary for our case, but leaving it for now so output matrix
+    // is identical to dtw on openAI timing.py
+    const int64_t result_n_cols = BT_MAX_ROWS - bt_row_idx - 1;
+    ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
+    for (int64_t i = 0; i < 2; ++i) {
+        for (int64_t j = 0; j < result_n_cols; ++j) {
+            int32_t v = ggml_get_i32_nd(bt, j + bt_row_idx + 1, i, 0, 0);
+            ggml_set_i32_nd(r, i, j, 0, 0, v);
+        }
+    }
+
+    return r;
+}
+
+struct median_filter_user_data {
+    int filter_width;
+};
+
+static void median_filter(struct ggml_tensor * dst, const struct ggml_tensor * a, int ith, int nth, void * userdata) {
+    int filter_width = ((median_filter_user_data *) userdata)->filter_width;
+    WHISPER_ASSERT(nth == 1);
+    WHISPER_ASSERT(ith == 0);
+    WHISPER_ASSERT(filter_width < a->ne[2]);
+    WHISPER_ASSERT(filter_width % 2);
+    WHISPER_ASSERT(ggml_n_dims(a) == 3);
+    WHISPER_ASSERT(a->type == GGML_TYPE_F32);
+
+    std::vector<float> filter;
+    filter.reserve(filter_width);
+    for (int64_t i = 0; i < a->ne[0]; ++i) {
+        for (int64_t j = 0; j < a->ne[1]; ++j) {
+            for (int64_t k = 0; k < a->ne[2]; ++k) {
+                for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) {
+                    // "reflect" padding
+                    int64_t idx = k + off;
+                    if (idx < 0) {
+                        idx = -idx;
+                    } else if (idx >= a->ne[2]) {
+                        idx = 2*(a->ne[2] - 1) - idx;
+                    }
+
+                    filter.push_back(ggml_get_f32_nd(a, i, j, idx, 0));
+                }
+                std::sort(filter.begin(), filter.end());
+                const float v = filter[filter.size()/2];
+                ggml_set_f32_nd(dst, i, j, k, 0, v);
+                filter.clear();
+            }
+        }
+    }
+}
+
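Note, not part of the patch: the reflect padding above mirrors indices around the edges (index -1 maps to 1, index n maps to n-2). The same filter standalone on a 1-D array, width 3:

    #include <cstdio>
    #include <vector>
    #include <algorithm>

    int main() {
        const std::vector<float> in = { 9.f, 1.f, 2.f, 8.f, 3.f };
        const int w = 3; // odd filter width, as the patch asserts

        std::vector<float> out(in.size()), win;
        for (int k = 0; k < (int) in.size(); ++k) {
            win.clear();
            for (int off = -w/2; off <= w/2; ++off) {
                int idx = k + off;
                if (idx < 0)                     idx = -idx;                          // reflect left
                else if (idx >= (int) in.size()) idx = 2*((int) in.size() - 1) - idx; // reflect right
                win.push_back(in[idx]);
            }
            std::sort(win.begin(), win.end());
            out[k] = win[win.size()/2];
        }

        for (float v : out) printf("%g ", v); // prints: 1 2 2 3 8
        printf("\n");
        return 0;
    }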
+static void whisper_exp_compute_token_level_timestamps_dtw(
+        struct whisper_context * ctx,
+          struct whisper_state * state,
+      struct whisper_full_params params,
+                            int   i_segment,
+                         size_t   n_segments,
+                            int   seek,
+                            int   n_frames,
+                            int   medfilt_width,
+                            int   n_threads)
+{
+    const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
+    WHISPER_ASSERT(medfilt_width % 2);
+    WHISPER_ASSERT(n_frames <= n_audio_ctx * 2);
+    WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE);
+
+    // FIXME: Allocating mem every time we call this func
+    // Our ggml buffer should be pre-allocated somewhere during init and reused
+    // when we call this function
+    struct ggml_init_params gparams = {
+        /*.mem_size   =*/ ctx->params.dtw_mem_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ false,
+    };
+    struct ggml_context * gctx = ggml_init(gparams);
+
+    // Build token sequence that will be passed to decoder
+    // sot + [lang] + text result + eot
+    std::vector<whisper_token> tokens = { whisper_token_sot(ctx), };
+    if (whisper_is_multilingual(ctx)) {
+        const int lang_id = whisper_lang_id(params.language);
+        state->lang_id = lang_id;
+        tokens.push_back(whisper_token_lang(ctx, lang_id));
+    }
+    const size_t sot_sequence_length = tokens.size();
+    tokens.push_back(whisper_token_not(ctx));
+    for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
+        auto & segment = state->result_all[i];
+        for (auto & t : segment.tokens) {
+            // Only text tokens
+            if (t.id < whisper_token_eot(ctx)) {
+                tokens.push_back(t.id);
+            }
+        }
+    }
+    tokens.push_back(whisper_token_eot(ctx));
+
+    // Get result tokens, pass them along to decoder to get cross attention QKs
+    // used in timestamping
+    // Decoder already returns only alignment head QKs, already concatenated in
+    // one tensor.
+    whisper_kv_cache_clear(state->kv_self);
+    whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
+    whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
+        WHISPER_LOG_INFO("DECODER FAILED\n");
+        WHISPER_ASSERT(0);
+    }
+    WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);
+
+    const auto n_audio_tokens = n_frames/2;
+    WHISPER_ASSERT(state->aheads_cross_QKs != NULL);
+    WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
+    const auto n_tokens = state->aheads_cross_QKs->ne[0];
+    const auto n_heads  = state->aheads_cross_QKs->ne[2];
+
+    // Copy data from decoder buffer to a local CPU tensor, discarding unused audio
+    // tokens (i.e. discarding rows at the end of tensor)
+    // IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
+    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
+    WHISPER_ASSERT(state->aheads_cross_QKs->type == GGML_TYPE_F32);
+    WHISPER_ASSERT(ggml_is_contiguous(state->aheads_cross_QKs));
+    ggml_tensor * w = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
+    auto & data = state->aheads_cross_QKs_data;
+    data.resize(n_tokens * n_audio_ctx * n_heads);
+    ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
+    for (int k = 0; k < n_heads; ++k) {
+        for (int j = 0; j < n_audio_tokens; ++j) {
+            memcpy(
+                (char *) w->data + j * w->nb[1] + k * w->nb[2],
+                data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
+                n_tokens * sizeof(float)
+            );
+        }
+    }
+
+    // Normalize - in original OpenAI code, this is done over dim=-2. In this case,
+    // we already permuted the N_TOKENS dimension to columns in the last loop, because
+    // ggml_norm operates over columns. Afterwards, permute to a shape that facilitates
+    // the mean operation (after median filter)
+    // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
+    // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    w = ggml_norm(gctx, w, 1e-9);
+    w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0, 3), 0, 2, 1, 3);
+
+    // Pass median filter - this is done over AUDIO_TOKENS dimension.
+    // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    // OUT: Same dims
+    median_filter_user_data mf_user_data = {medfilt_width};
+    w = ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data);
+
+    // Take mean over columns, scale by -1, reshape to 2D tensor, remove SOT sequence and EOT
+    // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
+    // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims
+    w = ggml_mean(gctx, w);
+    w = ggml_scale(gctx, w, -1.0);
+    w = ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]);
+
+    // Remove SOT sequence and EOT
+    // Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS
+    w = ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]);
+
+    // Compute
+    struct ggml_cgraph * gf = ggml_new_graph(gctx);
+    ggml_build_forward_expand(gf, w);
+    ggml_graph_compute_with_ctx(gctx, gf, n_threads);
+
+    ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
+
+    // Place timestamps on segments
+    int32_t last_v = 0;
+    auto seg_i = state->result_all.begin() + i_segment;
+    auto tok_i = seg_i->tokens.begin();
+    for (int i = 0; i < alignment->ne[1]; ++i) {
+        int32_t v = ggml_get_i32_nd(alignment, 0, i, 0, 0);
+        if (v != last_v) {
+            int32_t time_index = ggml_get_i32_nd(alignment, 1, i, 0, 0);
+            int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20 ms of audio
+            last_v = v;
+
+            // Skip non-text tokens
+            while (!(tok_i->id < whisper_token_eot(ctx))) {
+                ++tok_i;
+                if (tok_i == seg_i->tokens.end()) {
+                    ++seg_i;
+                    tok_i = seg_i->tokens.begin();
+                }
+            }
+
+            tok_i->t_dtw = timestamp;
+            ++tok_i;
+            if (tok_i == seg_i->tokens.end()) {
+                ++seg_i;
+                tok_i = seg_i->tokens.begin();
+            }
+        }
+    }
+
+    // Print DTW timestamps
+    /*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
+        auto & segment = state->result_all[i];
+        for (auto & t : segment.tokens) {
+            const char * tok = whisper_token_to_str(ctx, t.id);
+            fprintf(stderr, "|%s|(%.2f) ", tok, (float) t.t_dtw/100);
+        }
+        fprintf(stderr, "\n");
+    }*/
+
+    ggml_free(gctx);
+}
+
 void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
     g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
     g_state.log_callback_user_data = user_data;
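Note, not part of the patch: on units, each DTW column indexes one audio token (20 ms of audio), so time_index * 2 converts to the same 10 ms ticks used by t0/t1, and adding seek re-bases the value from the current decoding window to the whole file. A worked conversion with illustrative values:

    #include <cstdio>
    #include <cstdint>

    int main() {
        const int     time_index = 150;  // DTW column: 150 audio tokens = 3.0 s into the window
        const int     seek       = 3000; // window starts 30.0 s into the file (10 ms ticks)
        const int64_t t_dtw      = (int64_t) time_index * 2 + seek;

        printf("t_dtw = %lld ticks = %.2f s\n", (long long) t_dtw, (double) t_dtw / 100.0);
        // -> t_dtw = 3300 ticks = 33.00 s
        return 0;
    }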
diff --git a/whisper.h b/whisper.h
index a5371eb..2754337 100644
--- a/whisper.h
+++ b/whisper.h
@@ -84,9 +84,45 @@ extern "C" {
 
     typedef int32_t whisper_token;
     typedef int32_t whisper_seq_id;
 
+    enum whisper_alignment_heads_preset {
+        WHISPER_AHEADS_NONE,
+        WHISPER_AHEADS_N_TOP_MOST, // All heads from the N-top-most text-layers
+        WHISPER_AHEADS_CUSTOM,
+        WHISPER_AHEADS_TINY_EN,
+        WHISPER_AHEADS_TINY,
+        WHISPER_AHEADS_BASE_EN,
+        WHISPER_AHEADS_BASE,
+        WHISPER_AHEADS_SMALL_EN,
+        WHISPER_AHEADS_SMALL,
+        WHISPER_AHEADS_MEDIUM_EN,
+        WHISPER_AHEADS_MEDIUM,
+        WHISPER_AHEADS_LARGE_V1,
+        WHISPER_AHEADS_LARGE_V2,
+        WHISPER_AHEADS_LARGE_V3,
+    };
+
+    typedef struct whisper_ahead {
+        int n_text_layer;
+        int n_head;
+    } whisper_ahead;
+
+    typedef struct whisper_aheads {
+        size_t n_heads;
+        const whisper_ahead * heads;
+    } whisper_aheads;
+
     struct whisper_context_params {
         bool  use_gpu;
         int   gpu_device;  // CUDA device
+
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        bool dtw_token_timestamps;
+        enum whisper_alignment_heads_preset dtw_aheads_preset;
+
+        int dtw_n_top;
+        struct whisper_aheads dtw_aheads;
+
+        size_t dtw_mem_size; // TODO: remove
     };
 
     typedef struct whisper_token_data {
@@ -103,6 +139,11 @@ extern "C" {
         int64_t t0;        // start time of the token
         int64_t t1;        //   end time of the token
 
+        // [EXPERIMENTAL] Token-level timestamps with DTW
+        // do not use if you haven't computed token-level timestamps with dtw
+        // Roughly corresponds to the moment in audio in which the token was output
+        int64_t t_dtw;
+
         float vlen;        // voice length of the token
     } whisper_token_data;
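Note, not part of the patch: to close the loop, a sketch of reading t_dtw after transcription. It assumes ctx was created with dtw_token_timestamps enabled as at the top of the patch, pcmf32/n_samples stand for the caller's 16 kHz mono float samples, and t_dtw stays -1 on tokens that were not timestamped:

    #include <cstdio>
    #include "whisper.h"

    // ctx: context created with dtw_token_timestamps enabled
    void print_dtw_timestamps(struct whisper_context * ctx, const float * pcmf32, int n_samples) {
        struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

        if (whisper_full(ctx, wparams, pcmf32, n_samples) != 0) {
            return;
        }

        for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
            for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
                const whisper_token_data tok = whisper_full_get_token_data(ctx, i, j);
                if (tok.t_dtw >= 0) { // -1 when no DTW timestamp was computed
                    printf("%8.2f s  %s\n", tok.t_dtw/100.0, whisper_full_get_token_text(ctx, i, j));
                }
            }
        }
    }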