From d046dcee081118c9071bbc63dacdb359a58c467a Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Fri, 25 Aug 2023 19:05:02 +0300 Subject: [PATCH] Faster perplexity computation (#2786) Co-authored-by: Iwan Kawrakow --- examples/perplexity/perplexity.cpp | 67 +++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 11 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index a7bd9db2a..18635932b 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -6,6 +6,8 @@ #include #include #include +#include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -27,6 +29,40 @@ std::vector softmax(const std::vector& logits) { return probs; } +float log_softmax(int n_vocab, const float * logits, int tok) { + float max_logit = logits[0]; + for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]); + double sum_exp = 0.0; + for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit); + return logits[tok] - max_logit - log(sum_exp); +} + +void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector& workers, + double& nll, double& nll2) { + + std::mutex mutex; + int counter = 0; + auto compute = [&mutex, &counter, &nll, &nll2, n_vocab, logits, tokens, n_token] () { + double local_nll = 0, local_nll2 = 0; + while (true) { + std::unique_lock lock(mutex); + int i = counter++; + if (i >= n_token) { + nll += local_nll; nll2 += local_nll2; + break; + } + lock.unlock(); + double v = -log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]); + local_nll += v; + local_nll2 += v*v; + } + }; + for (auto& w : workers) w = std::thread(compute); + compute(); + for (auto& w : workers) w.join(); + +} + void perplexity_v2(llama_context * ctx, const gpt_params & params) { // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` @@ -166,9 +202,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) { int count = 0; double nll = 0.0; + double nll2 = 0.0; fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); + std::vector workers(std::thread::hardware_concurrency() - 1); + for (int i = 0; i < n_chunk; ++i) { const int start = i * params.n_ctx; const int end = start + params.n_ctx; @@ -228,26 +267,32 @@ void perplexity(llama_context * ctx, const gpt_params & params) { // Example, we have a context window of 512, we will compute perplexity for each of the // last 256 tokens. Then, we split the input up into context window size chunks to // process the entire prompt. - for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) { - // Calculate probability of next token, given the previous ones. - const std::vector tok_logits( - logits.begin() + (j + 0) * n_vocab, - logits.begin() + (j + 1) * n_vocab); + const int first = std::min(512, params.n_ctx/2); + process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first, workers, nll, nll2); + count += params.n_ctx - first - 1; - const float prob = softmax(tok_logits)[tokens[start + j + 1]]; - - nll += -std::log(prob); - ++count; - } // perplexity is e^(average negative log-likelihood) if (params.ppl_output_type == 0) { printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); } else { - printf("%8d %.4lf\n", i*params.n_ctx, std::exp(nll / count)); + double av = nll/count; + double av2 = nll2/count - av*av; + if (av2 > 0) av2 = sqrt(av2/(count-1)); + printf("%8d %.4lf %4lf %4lf\n", i*params.n_ctx, std::exp(nll / count), av, av2); } fflush(stdout); } printf("\n"); + nll2 /= count; + nll /= count; + nll2 -= nll * nll; + if (nll2 > 0) { + nll2 = sqrt(nll2/(count-1)); + double ppl = exp(nll); + printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl); + } else { + printf("Unexpected negative standard deviation of log(prob)\n"); + } } std::vector hellaswag_evaluate_tokens(llama_context * ctx, const std::vector& tokens, int n_past, int n_batch,