server : fix handling of characters that span multiple tokens when streaming (#4446)

This commit is contained in:
shibe2 2023-12-13 23:57:15 +04:00 committed by GitHub
parent 4d98d9a656
commit 948ff137ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -376,7 +376,6 @@ struct llama_client_slot
int32_t num_prompt_tokens = 0; int32_t num_prompt_tokens = 0;
int32_t num_prompt_tokens_processed = 0; int32_t num_prompt_tokens_processed = 0;
int32_t multibyte_pending = 0;
json prompt; json prompt;
std::string generated_text; std::string generated_text;
@ -425,7 +424,6 @@ struct llama_client_slot
stopped_word = false; stopped_word = false;
stopped_limit = false; stopped_limit = false;
stopping_word = ""; stopping_word = "";
multibyte_pending = 0;
n_past = 0; n_past = 0;
sent_count = 0; sent_count = 0;
sent_token_probs_index = 0; sent_token_probs_index = 0;
@ -992,35 +990,36 @@ struct llama_server_context
slot.generated_text += token_str; slot.generated_text += token_str;
slot.has_next_token = true; slot.has_next_token = true;
if (slot.multibyte_pending > 0) // check if there is incomplete UTF-8 character at the end
bool incomplete = false;
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i)
{ {
slot.multibyte_pending -= token_str.size(); unsigned char c = slot.generated_text[slot.generated_text.size() - i];
} if ((c & 0xC0) == 0x80)
else if (token_str.size() == 1) {
{ // continuation byte: 10xxxxxx
const char c = token_str[0]; continue;
// 2-byte characters: 110xxxxx 10xxxxxx }
if ((c & 0xE0) == 0xC0) if ((c & 0xE0) == 0xC0)
{ {
slot.multibyte_pending = 1; // 2-byte character: 110xxxxx ...
// 3-byte characters: 1110xxxx 10xxxxxx 10xxxxxx incomplete = i < 2;
} }
else if ((c & 0xF0) == 0xE0) else if ((c & 0xF0) == 0xE0)
{ {
slot.multibyte_pending = 2; // 3-byte character: 1110xxxx ...
// 4-byte characters: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx incomplete = i < 3;
} }
else if ((c & 0xF8) == 0xF0) else if ((c & 0xF8) == 0xF0)
{ {
slot.multibyte_pending = 3; // 4-byte character: 11110xxx ...
} incomplete = i < 4;
else
{
slot.multibyte_pending = 0;
} }
// else 1-byte character or invalid byte
break;
} }
if (slot.multibyte_pending == 0) if (!incomplete)
{ {
size_t pos = std::min(slot.sent_count, slot.generated_text.size()); size_t pos = std::min(slot.sent_count, slot.generated_text.size());
const std::string str_test = slot.generated_text.substr(pos); const std::string str_test = slot.generated_text.substr(pos);
@ -1055,7 +1054,7 @@ struct llama_server_context
} }
} }
if (slot.multibyte_pending > 0 && !slot.has_next_token) if (incomplete)
{ {
slot.has_next_token = true; slot.has_next_token = true;
} }