From 604b8bdfa6320bbcb018eebcc1252dfede603c6b Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Thu, 17 Aug 2023 19:54:44 -0400 Subject: [PATCH] Fix unicode in grammars (fixes #2501) (#2553) * Fix unicode in grammars (fixes #2501) * add more comments * fix test-llama-grammar --- llama.cpp | 161 +++++++++++++++++++++++++++++------ tests/test-llama-grammar.cpp | 2 +- 2 files changed, 135 insertions(+), 28 deletions(-) diff --git a/llama.cpp b/llama.cpp index b8cc22942..e02b60596 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2077,37 +2077,81 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co // grammar - internal // +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + struct llama_grammar { const std::vector> rules; std::vector> stacks; + + // buffer for partially generated UTF-8 sequence from accepted tokens + llama_partial_utf8 partial_utf8; }; struct llama_grammar_candidate { - size_t index; - const uint32_t * code_points; + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; }; -// NOTE: assumes valid utf8 (but checks for overrun) -// adds a terminating 0 for use as pointer -std::vector decode_utf8(const char * src) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; +// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as +// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. +std::pair, llama_partial_utf8> decode_utf8( + const char * src, + llama_partial_utf8 partial_start) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; const char * pos = src; std::vector code_points; + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; + + // continue previous decode, if applicable + while (*pos != 0 && n_remain > 0) { + uint8_t next_byte = static_cast(*pos); + if ((next_byte >> 6) != 2) { + // invalid sequence, abort + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); + } + value = (value << 6) + (next_byte & 0x3F); + ++pos; + --n_remain; + } + + if (partial_start.n_remain > 0 && n_remain == 0) { + code_points.push_back(value); + } + + // decode any subsequent utf-8 sequences, which may end in an incomplete one while (*pos != 0) { uint8_t first_byte = static_cast(*pos); uint8_t highbits = first_byte >> 4; - int len = lookup[highbits]; - uint8_t mask = (1 << (8 - len)) - 1; - uint32_t value = first_byte & mask; - const char * end = pos + len; // may overrun! - ++pos; - for ( ; pos < end && *pos != 0; ++pos) { - value = (value << 6) + (static_cast(*pos) & 0x3F); + n_remain = lookup[highbits] - 1; + + if (n_remain < 0) { + // invalid sequence, abort + code_points.clear(); + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); + } + + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; + ++pos; + while (*pos != 0 && n_remain > 0) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + ++pos; + --n_remain; + } + if (n_remain == 0) { + code_points.push_back(value); } - code_points.push_back(value); } code_points.push_back(0); - return code_points; + + return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); } // returns true iff pos points to the end of one of the definitions of a rule @@ -2144,6 +2188,56 @@ static std::pair llama_grammar_match_char( return std::make_pair(found == is_positive_char, pos); } +// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char +// range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static bool llama_grammar_match_partial_char( + const llama_grammar_element * pos, + const llama_partial_utf8 partial_utf8) { + + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + uint32_t partial_value = partial_utf8.value; + int n_remain = partial_utf8.n_remain; + + // invalid sequence or 7-bit char split across 2 bytes (overlong) + if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { + return false; + } + + // range of possible code points this partial UTF-8 sequence could complete to + uint32_t low = partial_value << (n_remain * 6); + uint32_t high = low | ((1 << (n_remain * 6)) - 1); + + if (low == 0) { + if (n_remain == 2) { + low = 1 << 11; + } else if (n_remain == 3) { + low = 1 << 16; + } + } + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + if (pos->value <= high && low <= pos[1].value) { + return is_positive_char; + } + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + if (low <= pos->value && pos->value <= high) { + return is_positive_char; + } + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return !is_positive_char; +} + + // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void llama_grammar_advance_stack( @@ -2244,8 +2338,11 @@ static std::vector llama_grammar_reject_candidates_for_ std::vector rejects; if (stack.empty()) { - // accept nothing; EOS is handled elsewhere - rejects.insert(rejects.end(), candidates.begin(), candidates.end()); + for (auto tok : candidates) { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } return rejects; } @@ -2253,10 +2350,15 @@ static std::vector llama_grammar_reject_candidates_for_ std::vector next_candidates; for (auto tok : candidates) { - if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) { - if (tok.code_points[1] != 0) { - next_candidates.push_back({ tok.index, tok.code_points + 1 }); + if (*tok.code_points == 0) { + // reached end of full codepoints in token, reject iff it ended in a partial sequence + // that cannot satisfy this position in grammar + if (tok.partial_utf8.n_remain != 0 && + !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + rejects.push_back(tok); } + } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); } else { rejects.push_back(tok); } @@ -2274,7 +2376,7 @@ static std::vector llama_grammar_reject_candidates_for_ auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); for (auto tok : next_rejects) { - rejects.push_back({ tok.index, tok.code_points - 1 }); + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); } return rejects; @@ -2339,7 +2441,7 @@ struct llama_grammar * llama_grammar_init( } } while (true); - return new llama_grammar{ std::move(vec_rules), std::move(stacks) }; + return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; } void llama_grammar_free(struct llama_grammar * grammar) { @@ -2645,8 +2747,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c const llama_token eos = llama_token_eos(); - std::vector> candidates_decoded; - std::vector candidates_grammar; + std::vector, llama_partial_utf8>> candidates_decoded; + std::vector candidates_grammar; for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; @@ -2658,8 +2760,10 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c } else if (*str == 0) { candidates->data[i].logit = -INFINITY; } else { - candidates_decoded.push_back(decode_utf8(str)); - candidates_grammar.push_back({ i, candidates_decoded.back().data() }); + candidates_decoded.push_back(decode_utf8(str, grammar->partial_utf8)); + candidates_grammar.push_back({ + i, candidates_decoded.back().first.data(), candidates_decoded.back().second + }); } } @@ -2860,11 +2964,14 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar } const char * str = llama_token_to_str(ctx, token); + // Note terminating 0 in decoded string - auto code_points = decode_utf8(str); + const auto decoded = decode_utf8(str, grammar->partial_utf8); + const auto & code_points = decoded.first; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); } + grammar->partial_utf8 = decoded.second; LLAMA_ASSERT(!grammar->stacks.empty()); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index f98c6531f..81c31e9e2 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -199,7 +199,7 @@ int main() uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point cp[0] = 37 + i; cp[1] = 0; - next_candidates[i] = {i, cp}; + next_candidates[i] = {i, cp, {}}; } std::vector>> expected_reject = {