From b08e75baea294e366628b898e85c0bd359b58115 Mon Sep 17 00:00:00 2001 From: goerch Date: Sat, 16 Sep 2023 13:41:33 +0200 Subject: [PATCH] Fixing the last deviations from sentencepiece indicated by test-tokenizer-1 (#3170) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix für #2721 * Reenable tokenizer test for LLaMa * Add `console.cpp` dependency * Fix dependency to `common` * Fixing wrong fix. * Make console usage platform specific Work on compiler warnings. * Adapting makefile * Remove trailing whitespace * Adapting the other parts of the makefile * Fix typo. * Fixing the last deviations from sentencepiece indicated by test-tokenizer-1 * Simplify logic * Add missing change... * Fix ugly compiler warning * llama_tokenize should accept strings containing NUL now * Adding huichen's test case --- common/common.cpp | 4 ++-- .../train-text-from-scratch.cpp | 4 ++-- llama.cpp | 6 ++++-- llama.h | 2 ++ tests/test-tokenizer-0-llama.cpp | 1 + tests/test-tokenizer-1-llama.cpp | 14 ++++++-------- 6 files changed, 17 insertions(+), 14 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 02ec0f8d0..6d655fd55 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -801,10 +801,10 @@ std::vector llama_tokenize( // upper limit for the number of tokens int n_tokens = text.length() + add_bos; std::vector result(n_tokens); - n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 947aa7ed3..59c90c7ba 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -965,10 +965,10 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto buf[size] = '\0'; - int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); + int n_tokens = llama_tokenize(lctx, buf.data(), buf.size(), out.data(), out.size(), false); if (n_tokens < 0) { out.resize(-n_tokens); - n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); + n_tokens = llama_tokenize(lctx, buf.data(), buf.size(), out.data(), out.size(), false); } GGML_ASSERT(n_tokens >= 0); out.resize(n_tokens); diff --git a/llama.cpp b/llama.cpp index a65026122..0b334b4e9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7032,19 +7032,21 @@ llama_token llama_token_nl(const struct llama_context * ctx) { int llama_tokenize( struct llama_context * ctx, const char * text, + int text_len, llama_token * tokens, int n_max_tokens, bool add_bos) { - return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos); + return llama_tokenize_with_model(&ctx->model, text, text_len, tokens, n_max_tokens, add_bos); } int llama_tokenize_with_model( const struct llama_model * model, const char * text, + int text_len, llama_token * tokens, int n_max_tokens, bool add_bos) { - auto res = llama_tokenize_internal(model->vocab, text, add_bos); + auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos); if (n_max_tokens < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); diff --git a/llama.h b/llama.h index c6ee038c7..369be048c 100644 --- a/llama.h +++ b/llama.h @@ -374,6 +374,7 @@ extern "C" { LLAMA_API int llama_tokenize( struct llama_context * ctx, const char * text, + int text_len, llama_token * tokens, int n_max_tokens, bool add_bos); @@ -381,6 +382,7 @@ extern "C" { LLAMA_API int llama_tokenize_with_model( const struct llama_model * model, const char * text, + int text_len, llama_token * tokens, int n_max_tokens, bool add_bos); diff --git a/tests/test-tokenizer-0-llama.cpp b/tests/test-tokenizer-0-llama.cpp index edbd86f85..dfb2e81a9 100644 --- a/tests/test-tokenizer-0-llama.cpp +++ b/tests/test-tokenizer-0-llama.cpp @@ -36,6 +36,7 @@ static const std::map> & k_tests() { { " Hello" , { 1678, 15043, }, }, { " Hello" , { 268, 15043, }, }, { " Hello\n Hello" , { 268, 15043, 13, 1678, 15043, }, }, + { " (" , { 29871, 313, }, }, }; return _k_tests; diff --git a/tests/test-tokenizer-1-llama.cpp b/tests/test-tokenizer-1-llama.cpp index 804ea2486..a95d462cf 100644 --- a/tests/test-tokenizer-1-llama.cpp +++ b/tests/test-tokenizer-1-llama.cpp @@ -87,10 +87,9 @@ int main(int argc, char **argv) { std::vector tokens = llama_tokenize(ctx, str, false); std::string check = llama_detokenize_spm(ctx, tokens); if (check != str) { - fprintf(stderr, "%s : error: token %d detokenizes to >%s<(%llu) but tokenization of this detokenizes to >%s<(%llu)\n", + fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n", __func__, i, str.c_str(), str.length(), check.c_str(), check.length()); - if(i != 3) - return 2; + return 2; } } @@ -99,11 +98,10 @@ int main(int argc, char **argv) { std::string str = codepoint_to_utf8(cp); std::vector tokens = llama_tokenize(ctx, str, false); std::string check = llama_detokenize_spm(ctx, tokens); - if (str != check) { - fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n", + if (cp != 9601 && str != check) { + fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n", __func__, cp, check.c_str(), check.length(), str.c_str(), str.length()); - if(cp != 0 && cp != 9601) - return 3; + return 3; } } } @@ -112,7 +110,7 @@ int main(int argc, char **argv) { std::vector tokens = llama_tokenize(ctx, str, false); std::string check = llama_detokenize_spm(ctx, tokens); if (str != check) { - fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n", + fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n", __func__, cp, check.c_str(), check.length(), str.c_str(), str.length()); return 4; }