whisper : allow non-CoreML fallback when Core ML cannot be loaded (#812)

If the Core ML model cannot be loaded, continue without Core ML instead of
returning an error. This allows a single build to transcribe using Core ML models
where available, and regular models when not.
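In outline, the change is a two-level gate: compile-time Core ML support is still selected with WHISPER_USE_COREML, but whether Core ML is actually used is now decided at runtime, by checking whether the Core ML context loaded. A minimal standalone sketch of that pattern follows; the names (MY_USE_COREML, my_state, my_encode) are hypothetical stand-ins, not the actual whisper.cpp code:

    // Sketch only: MY_USE_COREML / my_state stand in for WHISPER_USE_COREML /
    // whisper_state; the real logic lives in whisper_encode_internal().
    struct my_coreml_context;

    struct my_state {
    #ifdef MY_USE_COREML
        my_coreml_context * ctx_coreml = nullptr; // nullptr => load failed or skipped
    #endif
    };

    static void my_encode(my_state & s) {
    #ifndef MY_USE_COREML
        const bool use_coreml = false;                   // Core ML compiled out
        (void) s;
    #else
        const bool use_coreml = s.ctx_coreml != nullptr; // runtime decision
    #endif

        if (!use_coreml) {
            // ... regular ggml encoder path ...
        } else {
            // ... hand the mel spectrogram to the Core ML encoder ...
        }
    }

The same nullptr-as-absent convention drives the init and free changes further down in the diff.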
Canis Lupus 2023-04-29 08:49:02 +01:00 committed by GitHub
parent 3e82ff4747
commit 94a7cd2a07
1 changed file with 294 additions and 279 deletions

whisper.cpp

@@ -592,7 +592,7 @@ struct whisper_state {
     std::string path_model; // populated by whisper_init_from_file()
 #ifdef WHISPER_USE_COREML
-    whisper_coreml_context * ctx_coreml;
+    whisper_coreml_context * ctx_coreml = nullptr;
 #endif

     // [EXPERIMENTAL] token-level timestamps data
@@ -1385,9 +1385,16 @@ static bool whisper_encode_internal(
         }
     }

-#ifndef WHISPER_USE_COREML
     struct ggml_tensor * cur;

+#ifndef WHISPER_USE_COREML
+    const bool use_coreml = false;
+#else
+    const bool use_coreml = wstate.ctx_coreml != nullptr;
+#endif
+
+    if (!use_coreml)
+    {
     // convolution + gelu
     {
         wstate.use_buf(ctx0, 1);
@@ -1497,7 +1504,7 @@ static bool whisper_encode_internal(
             wstate.use_buf(ctx0, 0);

 #ifdef WHISPER_USE_FLASH_ATTN
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
                     ggml_cpy(ctx0,
@@ -1522,7 +1529,7 @@ static bool whisper_encode_internal(
                         ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head));

             struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
 #else
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
                     ggml_cpy(ctx0,
@@ -1568,7 +1575,7 @@ static bool whisper_encode_internal(
                 );

             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
 #endif

             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

             wstate.use_buf(ctx0, 1);
@@ -1618,13 +1625,13 @@ static bool whisper_encode_internal(
                         ggml_repeat(ctx0, layer.mlp_ln_b, cur));
             }

 #ifdef WHISPER_USE_FLASH_FF
             wstate.use_buf(ctx0, 0);

             cur = ggml_flash_ff(ctx0,
                 ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.wtype, n_state, n_ctx)),
                 layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
 #else
             wstate.use_buf(ctx0, 0);

             // fully connected
@@ -1655,7 +1662,7 @@ static bool whisper_encode_internal(
             cur = ggml_add(ctx0,
                 ggml_repeat(ctx0, layer.mlp_1_b, cur),
                 cur);
 #endif
         }

         wstate.use_buf(ctx0, 3);
@@ -1693,12 +1700,16 @@ static bool whisper_encode_internal(
         //ggml_graph_print(&gf);
     }
-#else
+    }
+#ifdef WHISPER_USE_COREML
+    else
+    {
     wstate.use_buf(ctx0, -1);

-    struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
+    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);

     whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
+    }
 #endif

     // cur
@@ -2569,10 +2580,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
         if (!state->ctx_coreml) {
             fprintf(stderr, "%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
+#ifndef WHISPER_COREML_ALLOW_FALLBACK
             return nullptr;
-        }
+#endif
+        } else {
+            fprintf(stderr, "%s: Core ML model loaded\n", __func__);
+        }
 #endif

     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
@@ -2745,8 +2758,10 @@ void whisper_free_state(struct whisper_state * state)
     }

 #ifdef WHISPER_USE_COREML
+    if (state->ctx_coreml != nullptr) {
         whisper_coreml_free(state->ctx_coreml);
         state->ctx_coreml = nullptr;
+    }
 #endif

     delete state;
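Taken together, the three call sites make the Core ML context optional across its whole lifecycle: init may leave it null, encode branches on it, and free guards it. Below is a runnable end-to-end sketch of the resulting behavior, again with hypothetical names (ctxml_load simulates whisper_coreml_init failing); compile with -DALLOW_FALLBACK to see a single binary carry on instead of aborting:

    #include <cstdio>

    struct ctxml { int unused; };

    // Simulated loader: always fails, standing in for whisper_coreml_init()
    // returning nullptr when the Core ML model cannot be loaded.
    static ctxml * ctxml_load(const char * path) { (void) path; return nullptr; }
    static void    ctxml_free(ctxml * c)         { delete c; }

    struct my_state { ctxml * ctx_coreml = nullptr; };

    static bool my_init(my_state & s, const char * path) {
        s.ctx_coreml = ctxml_load(path);
        if (!s.ctx_coreml) {
            fprintf(stderr, "failed to load Core ML model from '%s'\n", path);
    #ifndef ALLOW_FALLBACK
            return false; // old behavior: initialization fails outright
    #endif
        }
        return true;      // fallback build: carry on with ctx_coreml == nullptr
    }

    static void my_free(my_state & s) {
        if (s.ctx_coreml != nullptr) { // context may never have loaded
            ctxml_free(s.ctx_coreml);
            s.ctx_coreml = nullptr;
        }
    }

    int main() {
        my_state s;
        if (!my_init(s, "model.mlmodelc")) return 1;
        printf("encoding with %s\n", s.ctx_coreml ? "Core ML" : "ggml (fallback)");
        my_free(s);
        return 0;
    }

In the real tree the guard is WHISPER_COREML_ALLOW_FALLBACK; since this commit touches only whisper.cpp, defining that macro is left to the build configuration.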