From e3c5e2cba8173070a64321885ff3325a3a1563bb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 12 Feb 2024 19:53:51 +0200 Subject: [PATCH] whisper : fix external encoder (#1860) --- whisper.cpp | 41 +++++++++-------------------------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index dec9957..536adc3 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1659,22 +1659,9 @@ static struct ggml_cgraph * whisper_build_graph_conv( ggml_set_name(cur, "embd_conv"); wstate.embd_conv = cur; } else { -#ifdef WHISPER_USE_COREML - cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); - ggml_allocr_alloc(alloc, cur); + ggml_build_forward_expand(gf, mel); - if (!ggml_allocr_is_measure(alloc)) { - whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data); - } -#endif -#ifdef WHISPER_USE_OPENVINO cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); - ggml_allocr_alloc(alloc, cur); - - if (!ggml_allocr_is_measure(alloc)) { - whisper_openvino_encode(wstate.ctx_openvino, mel, cur); - } -#endif ggml_set_name(cur, "embd_enc"); wstate.embd_enc = cur; @@ -1708,14 +1695,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder( ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); - //ggml_allocr * alloc = wstate.alloc_encode.alloc; - - //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state); - //ggml_allocr_alloc(alloc, cur); - - //if (!ggml_allocr_is_measure(alloc)) { - // ggml_backend_tensor_copy(wstate.embd_conv, cur); - //} struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv); const float KQscale = 1.0f/sqrtf(float(n_state)/n_head); @@ -1957,14 +1936,6 @@ static struct ggml_cgraph * whisper_build_graph_cross( ggml_cgraph * gf = ggml_new_graph(ctx0); - //ggml_allocr * alloc = wstate.alloc_cross.alloc; - - //struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); - //ggml_allocr_alloc(alloc, cur); - - //if (!ggml_allocr_is_measure(alloc)) { - // ggml_backend_tensor_copy(wstate.embd_enc, cur); - //} struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc); const float Kscale = pow(float(n_state) / n_head, -0.25); @@ -2037,13 +2008,13 @@ static bool whisper_encode_internal( return false; } + struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); + // set the input { const auto & mel_inp = wstate.mel; const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx; - struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); - assert(mel->type == GGML_TYPE_F32); assert(mel_inp.n_mel == wctx.model.hparams.n_mels); @@ -2068,6 +2039,12 @@ static bool whisper_encode_internal( if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { return false; } + } else { +#if defined(WHISPER_USE_COREML) + whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data); +#elif defined(WHISPER_USE_OPENVINO) + whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc); +#endif } }