From 332bdfd7980718abf664bfa5460f2288a3314984 Mon Sep 17 00:00:00 2001
From: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com>
Date: Mon, 11 Mar 2024 17:09:32 +0900
Subject: [PATCH] server : maintain chat completion id for streaming responses
 (#5988)

* server: maintain chat completion id for streaming responses

* Update examples/server/utils.hpp

* Update examples/server/utils.hpp

---------

Co-authored-by: Georgi Gerganov
---
 examples/server/server.cpp |  7 ++++---
 examples/server/utils.hpp  | 12 ++++++------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c7d3ed01b..3951507aa 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3195,11 +3195,12 @@ int main(int argc, char ** argv) {
         ctx_server.queue_results.add_waiting_task_id(id_task);
         ctx_server.request_completion(id_task, -1, data, false, false);
 
+        const auto completion_id = gen_chatcmplid();
         if (!json_value(data, "stream", false)) {
             server_task_result result = ctx_server.queue_results.recv(id_task);
 
             if (!result.error && result.stop) {
-                json result_oai = format_final_response_oaicompat(data, result.data);
+                json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
 
                 res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
             } else {
@@ -3208,11 +3209,11 @@ int main(int argc, char ** argv) {
             }
             ctx_server.queue_results.remove_waiting_task_id(id_task);
         } else {
-            const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
                 while (true) {
                     server_task_result result = ctx_server.queue_results.recv(id_task);
                     if (!result.error) {
-                        std::vector<json> result_array = format_partial_response_oaicompat(result.data);
+                        std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
 
                         for (auto it = result_array.begin(); it != result_array.end(); ++it) {
                             if (!it->empty()) {
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index df0a27782..f27af81e9 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -378,7 +378,7 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, json result, bool streaming = false) {
+static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -412,7 +412,7 @@ static json format_final_response_oaicompat(const json & request, json result, b
             {"prompt_tokens", num_prompt_tokens},
             {"total_tokens", num_tokens_predicted + num_prompt_tokens}
         }},
-        {"id", gen_chatcmplid()}
+        {"id", completion_id}
     };
 
     if (server_verbose) {
@@ -427,7 +427,7 @@ static json format_final_response_oaicompat(const json & request, json result, b
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(json result) {
+static std::vector<json> format_partial_response_oaicompat(json result, const std::string & completion_id) {
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
@@ -471,7 +471,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
                         {"role", "assistant"}
                     }}}})},
             {"created", t},
-            {"id", gen_chatcmplid()},
+            {"id", completion_id},
             {"model", modelname},
             {"object", "chat.completion.chunk"}};
 
@@ -482,7 +482,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
                         {"content", content}}}
                         }})},
             {"created", t},
-            {"id", gen_chatcmplid()},
+            {"id", completion_id},
             {"model", modelname},
             {"object", "chat.completion.chunk"}};
 
@@ -509,7 +509,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
     json ret = json {
         {"choices", choices},
         {"created", t},
-        {"id", gen_chatcmplid()},
+        {"id", completion_id},
         {"model", modelname},
         {"object", "chat.completion.chunk"}
     };
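
Note on the fix (an illustrative sketch, not part of the patch): the point of
threading completion_id through is that the id is generated once per request,
before the response path forks, and the streaming lambda captures it by value,
so every "chat.completion.chunk" of one response carries the same "id", as the
OpenAI-compatible API expects. The standalone sketch below shows this shape in
isolation; gen_chatcmplid() here is a simplified stand-in for the helper in
examples/server/utils.hpp, and the JSON is hand-rolled for brevity where the
server uses nlohmann::json.

#include <iostream>
#include <random>
#include <string>

// Simplified stand-in for gen_chatcmplid(): "chatcmpl-" plus random
// alphanumeric characters (the real helper lives in examples/server/utils.hpp).
static std::string gen_chatcmplid() {
    static const std::string chars = "abcdefghijklmnopqrstuvwxyz0123456789";
    std::mt19937 rng{std::random_device{}()};
    std::uniform_int_distribution<size_t> dist(0, chars.size() - 1);
    std::string id = "chatcmpl-";
    for (int i = 0; i < 30; ++i) {
        id += chars[dist(rng)];
    }
    return id;
}

int main() {
    // Before the patch, each chunk called gen_chatcmplid() itself, so the
    // chunks of a single streamed response carried different "id" values.
    // After the patch, the id is generated once and captured by value in the
    // chunked content provider, mirroring the lambda capture in server.cpp.
    const auto completion_id = gen_chatcmplid();

    const auto make_chunk = [completion_id](const std::string & content) {
        return "data: {\"id\":\"" + completion_id +
               "\",\"object\":\"chat.completion.chunk\","
               "\"choices\":[{\"index\":0,\"delta\":{\"content\":\"" +
               content + "\"}}]}";
    };

    for (const char * token : {"Hel", "lo", "!"}) {
        std::cout << make_chunk(token) << "\n\n"; // every chunk shares one id
    }
    return 0;
}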