add dtw to server

pull/1983/head
Emmanuel Schmidbauer 2024-03-21 12:32:36 -04:00
parent 84d34cdf46
commit b40cb896ad
1 changed files with 43 additions and 0 deletions

View File

@ -89,6 +89,8 @@ namespace
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
std::string openvino_encode_device = "CPU";
std::string dtw = "";
};
void whisper_print_usage(int /*argc*/, char **argv, const whisper_params &params, const server_params &sparams)
@ -128,6 +130,7 @@ namespace
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
// server params
fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
@ -269,6 +272,10 @@ namespace
{
params.openvino_encode_device = argv[++i];
}
else if (arg == "-dtw" || arg == "--dtw")
{
params.dtw = argv[++i];
}
else if (arg == "-ng" || arg == "--no-gpu")
{
params.use_gpu = false;
@ -658,6 +665,41 @@ int main(int argc, char **argv)
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
if (!params.dtw.empty())
{
cparams.dtw_token_timestamps = true;
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
if (params.dtw == "tiny")
cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
if (params.dtw == "tiny.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
if (params.dtw == "base")
cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
if (params.dtw == "base.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
if (params.dtw == "small")
cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
if (params.dtw == "small.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
if (params.dtw == "medium")
cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
if (params.dtw == "medium.en")
cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
if (params.dtw == "large.v1")
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
if (params.dtw == "large.v2")
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
if (params.dtw == "large.v3")
cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE)
{
fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
return 3;
}
}
struct whisper_context *ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr)
@ -1026,6 +1068,7 @@ int main(int argc, char **argv)
if (!params.no_timestamps) {
word["start"] = token.t0 * 0.01;
word["end"] = token.t1 * 0.01;
word["t_dtw"] = token.t_dtw;
}
word["probability"] = token.p;
total_logprob += token.plog;