diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6f3ca6b..8b6e469 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -397,6 +397,13 @@ std::string output_str(struct whisper_context * ctx, const whisper_params & para return result.str(); } +bool parse_str_to_bool(const std::string & s) { + if (s == "true" || s == "1" || s == "yes" || s == "y") { + return true; + } + return false; +} + void get_req_parameters(const Request & req, whisper_params & params) { if (req.has_file("offset_t")) @@ -415,6 +422,62 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.max_context = std::stoi(req.get_file_value("max_context").content); } + if (req.has_file("max_len")) + { + params.max_len = std::stoi(req.get_file_value("max_len").content); + } + if (req.has_file("best_of")) + { + params.best_of = std::stoi(req.get_file_value("best_of").content); + } + if (req.has_file("beam_size")) + { + params.beam_size = std::stoi(req.get_file_value("beam_size").content); + } + if (req.has_file("word_thold")) + { + params.word_thold = std::stof(req.get_file_value("word_thold").content); + } + if (req.has_file("entropy_thold")) + { + params.entropy_thold = std::stof(req.get_file_value("entropy_thold").content); + } + if (req.has_file("logprob_thold")) + { + params.logprob_thold = std::stof(req.get_file_value("logprob_thold").content); + } + if (req.has_file("debug_mode")) + { + params.debug_mode = parse_str_to_bool(req.get_file_value("debug_mode").content); + } + if (req.has_file("translate")) + { + params.translate = parse_str_to_bool(req.get_file_value("translate").content); + } + if (req.has_file("diarize")) + { + params.diarize = parse_str_to_bool(req.get_file_value("diarize").content); + } + if (req.has_file("tinydiarize")) + { + params.tinydiarize = parse_str_to_bool(req.get_file_value("tinydiarize").content); + } + if (req.has_file("split_on_word")) + { + params.split_on_word = parse_str_to_bool(req.get_file_value("split_on_word").content); + } + if (req.has_file("no_timestamps")) + { + params.no_timestamps = parse_str_to_bool(req.get_file_value("no_timestamps").content); + } + if (req.has_file("language")) + { + params.language = req.get_file_value("language").content; + } + if (req.has_file("detect_language")) + { + params.detect_language = parse_str_to_bool(req.get_file_value("detect_language").content); + } if (req.has_file("prompt")) { params.prompt = req.get_file_value("prompt").content; @@ -482,6 +545,9 @@ int main(int argc, char ** argv) { std::string const default_content = "hello"; + // store default params so we can reset after each inference request + whisper_params default_params = params; + // this is only called if no index.html is found in the public --path svr.Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){ res.set_content(default_content, "text/html"); @@ -724,6 +790,9 @@ int main(int argc, char ** argv) { "application/json"); } + // reset params to thier defaults + params = default_params; + // return whisper model mutex lock whisper_mutex.unlock(); });