From 03f7e335604b3d68f74995aa2ccb4955833ee423 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 25 Mar 2023 20:51:14 +0200
Subject: [PATCH] Cleanup STL headers + fix embedding examples + minor stuff

---
 examples/embedding/embedding.cpp   | 15 +++++----------
 examples/perplexity/perplexity.cpp |  8 --------
 llama.cpp                          | 22 ++++++++++++++--------
 llama.h                            |  1 +
 4 files changed, 20 insertions(+), 26 deletions(-)

diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 3015293f7..d397f35fd 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -1,15 +1,6 @@
 #include "common.h"
 #include "llama.h"
 
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
 int main(int argc, char ** argv) {
     gpt_params params;
     params.model = "models/llama-7B/ggml-model.bin";
@@ -94,9 +85,13 @@ int main(int argc, char ** argv) {
             }
         }
 
+        const int n_embd = llama_n_embd(ctx);
         const auto embeddings = llama_get_embeddings(ctx);
 
-        // TODO: print / use the embeddings
+        for (int i = 0; i < n_embd; i++) {
+            printf("%f ", embeddings[i]);
+        }
+        printf("\n");
     }
 
     llama_print_timings(ctx);
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index f0266a01f..f617ba365 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -1,14 +1,6 @@
 #include "common.h"
 #include "llama.h"
 
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
 std::vector<float> softmax(const std::vector<float>& logits) {
     std::vector<float> probs(logits.size());
     float max_logit = logits[0];
diff --git a/llama.cpp b/llama.cpp
index 0015edec1..2bd520353 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1261,10 +1261,10 @@ static llama_vocab::id llama_sample_top_p_top_k(
         double repeat_penalty) {
     auto & rng = lctx.rng;
 
-    const auto & vocab = lctx.vocab;
-    const auto & logits = lctx.logits;
+    const int n_logits = lctx.model.hparams.n_vocab;
 
-    int n_logits = vocab.id_to_token.size();
+    const auto & logits = lctx.logits;
+    const auto * plogits = logits.data() + logits.size() - n_logits;
 
     std::vector<std::pair<double, llama_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
@@ -1276,13 +1276,13 @@ static llama_vocab::id llama_sample_top_p_top_k(
             // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
             if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
                 // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
-                if (logits[i] < 0.0) {
-                    logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i));
+                if (plogits[i] < 0.0) {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
                 } else {
-                    logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i));
+                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
                 }
             } else {
-                logits_id.push_back(std::make_pair(logits[i]*scale, i));
+                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
             }
         }
     }
@@ -1677,6 +1677,8 @@ struct llama_context * llama_init_from_file(
     }
 
     const auto & hparams = ctx->model.hparams;
+
+    // resized during inference
     if (params.logits_all) {
         ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
     } else {
@@ -1684,7 +1686,7 @@ struct llama_context * llama_init_from_file(
     }
 
     if (params.embedding){
-        ctx->embedding.reserve(hparams.n_embd);
+        ctx->embedding.resize(hparams.n_embd);
     }
 
     ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));
@@ -1761,6 +1763,10 @@ int llama_n_ctx(struct llama_context * ctx) {
     return ctx->model.hparams.n_ctx;
 }
 
+int llama_n_embd(struct llama_context * ctx) {
+    return ctx->model.hparams.n_embd;
+}
+
 float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
diff --git a/llama.h b/llama.h
index 827abc1f2..ebf55f41c 100644
--- a/llama.h
+++ b/llama.h
@@ -109,6 +109,7 @@ extern "C" {
 
     LLAMA_API int llama_n_vocab(struct llama_context * ctx);
     LLAMA_API int llama_n_ctx  (struct llama_context * ctx);
+    LLAMA_API int llama_n_embd (struct llama_context * ctx);
 
     // Token logits obtained from the last call to llama_eval()
     // The logits for the last token are stored in the last row
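
Usage note (not part of the patch): a minimal sketch of how a caller might
consume the new llama_n_embd() accessor together with the existing
llama_get_embeddings(), assuming a context created with params.embedding
enabled and a completed llama_eval() call; the helper name print_embeddings
is hypothetical:

    #include "llama.h"

    #include <cstdio>

    // Print the embedding vector produced by the last llama_eval() call.
    // Mirrors the loop this patch adds to examples/embedding/embedding.cpp.
    static void print_embeddings(struct llama_context * ctx) {
        const int     n_embd     = llama_n_embd(ctx);          // new accessor from this patch
        const float * embeddings = llama_get_embeddings(ctx);  // valid when params.embedding = true

        for (int i = 0; i < n_embd; i++) {
            printf("%f ", embeddings[i]);
        }
        printf("\n");
    }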
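
Background sketch on the plogits change (my illustration, not from the patch):
with params.logits_all the lctx.logits buffer holds one row of n_vocab values
per evaluated token, so the sampler must read the final row rather than
indexing from the start of the buffer; the helper name last_logits_row is
hypothetical:

    #include <cstddef>
    #include <vector>

    // `logits` holds n_tokens * n_vocab floats, one row per evaluated token.
    // Sampling wants the row for the most recent token: the last n_vocab entries.
    static const float * last_logits_row(const std::vector<float> & logits, int n_vocab) {
        return logits.data() + logits.size() - (size_t) n_vocab;
    }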
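
Related note on the reserve -> resize change for ctx->embedding (my reading,
not stated in the patch): std::vector::reserve only allocates capacity and
leaves size() at zero, so reading n_embd floats back through
llama_get_embeddings() needs resize to actually create the elements. A
standalone demonstration of the distinction:

    #include <cassert>
    #include <vector>

    int main() {
        std::vector<float> v;
        v.reserve(8);              // capacity only: size() is still 0
        assert(v.size() == 0);
        v.resize(8);               // now 8 value-initialized elements exist
        assert(v.size() == 8);
        return 0;
    }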