grammar : pre-computed pieces + reserve mem + less string copies (#4330)

* reserve space for codepoints

* improvement for the appended 0

* used precomputed token text for grammar sample

* reserve canidates_decoded

* reserve canidates_grammar

* remove candidates_decoded

* Revert "remove candidates_decoded"

This reverts commit 3773328080.

* changed decode_utf8 to take src by ref
This commit is contained in:
Marcus Dunn 2023-12-05 10:55:12 -10:00 committed by GitHub
parent 5aa365d88f
commit 5f6e0c0dff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6851,14 +6851,13 @@ struct llama_grammar_candidate {
// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8( static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const char * src, const std::string & src,
size_t n_src,
llama_partial_utf8 partial_start) { llama_partial_utf8 partial_start) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
const char * pos = src; const char * pos = src.c_str();
std::vector<uint32_t> code_points; std::vector<uint32_t> code_points;
// common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0. // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
code_points.reserve(n_src + 1); code_points.reserve(src.size() + 1);
uint32_t value = partial_start.value; uint32_t value = partial_start.value;
int n_remain = partial_start.n_remain; int n_remain = partial_start.n_remain;
@ -6909,13 +6908,6 @@ static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
} }
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
std::string src,
llama_partial_utf8 partial_start
) {
return decode_utf8(src.c_str(), src.size(), partial_start);
}
// returns true iff pos points to the end of one of the definitions of a rule // returns true iff pos points to the end of one of the definitions of a rule
static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
switch (pos->type) { switch (pos->type) {
@ -7554,11 +7546,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
const llama_token eos = llama_token_eos(&ctx->model); const llama_token eos = llama_token_eos(&ctx->model);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded; std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
candidates_decoded.reserve(candidates->size);
std::vector<llama_grammar_candidate> candidates_grammar; std::vector<llama_grammar_candidate> candidates_grammar;
candidates_grammar.reserve(candidates->size);
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id; const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_piece(ctx, id); const std::string & piece = ctx->model.vocab.id_to_token[id].text;
if (id == eos) { if (id == eos) {
if (!allow_eos) { if (!allow_eos) {
candidates->data[i].logit = -INFINITY; candidates->data[i].logit = -INFINITY;
@ -7770,7 +7764,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false); GGML_ASSERT(false);
} }
const std::string piece = llama_token_to_piece(ctx, token); const std::string & piece = ctx->model.vocab.id_to_token[token].text;
// Note terminating 0 in decoded string // Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8); const auto decoded = decode_utf8(piece, grammar->partial_utf8);