From 5e9ff54a675d163d9f42aad1b5b3e734f17b2701 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Sun, 20 Aug 2023 16:44:46 +0300 Subject: [PATCH] More efficient Hellaswag implementation (#2677) Co-authored-by: Iwan Kawrakow --- examples/perplexity/perplexity.cpp | 92 +++++++++++++++++++++++------- 1 file changed, 70 insertions(+), 22 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index b9b28a20b..682c39b16 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -209,50 +210,97 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) { double acc = 0.0f; const int n_vocab = llama_n_vocab(ctx); + std::vector tok_logits(n_vocab); + for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) { // Tokenize the context to count tokens std::vector context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos); size_t context_size = context_embd.size(); - for (size_t ending_idx=0;ending_idx<4;ending_idx++) { + // Do the 1st ending + // In this case we include the context when evaluating + auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos); + auto query_size = query_embd.size(); + //printf("First query: %d\n",(int)query_size); + + // Stop if query wont fit the ctx window + if (query_size > (size_t)params.n_ctx) { + fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size); + return; + } + + // Speedup small evaluations by evaluating atleast 32 tokens + if (query_size < 32) { + query_embd.resize(32); + } + + // Evaluate the query + if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return; + } + + auto query_logits = llama_get_logits(ctx); + + std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float)); + const auto first_probs = softmax(tok_logits); + + hs_data[task_idx].ending_logprob_count[0] = 1; + hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]); + + // Calculate the logprobs over the ending + for (size_t j = context_size; j < query_size - 1; j++) { + + std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float)); + + const float prob = softmax(tok_logits)[query_embd[j + 1]]; + + hs_data[task_idx].ending_logprob[0] += std::log(prob); + hs_data[task_idx].ending_logprob_count[0]++; + } + + // Calculate the mean token logprob for acc_norm + hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0]; + + // Do the remaining endings + // For these, we use the bare ending with n_past = context_size + // + for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) { // Tokenize the query - std::vector query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos); - size_t query_size = query_embd.size(); + query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false); + query_size = query_embd.size(); + //printf("Second query: %d\n",(int)query_size); // Stop if query wont fit the ctx window - if (query_size > (size_t)params.n_ctx) { + if (context_size + query_size > (size_t)params.n_ctx) { fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size); return; } // Speedup small evaluations by evaluating atleast 32 tokens - if (query_size < 32) { - query_embd.resize(32); - } + // No, resizing to 32 is actually slightly slower (at least on CUDA) + //if (query_size < 32) { + // query_embd.resize(32); + //} // Evaluate the query - if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) { + if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return; } - const auto query_logits = llama_get_logits(ctx); - std::vector logits; - logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab); + query_logits = llama_get_logits(ctx); - hs_data[task_idx].ending_logprob_count[ending_idx] = 0; - hs_data[task_idx].ending_logprob[ending_idx] = 0.0f; + hs_data[task_idx].ending_logprob_count[ending_idx] = 1; + hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]); // Calculate the logprobs over the ending - for (size_t j = context_size-1; j < query_size - 1; j++) { - // Calculate probability of next token, given the previous ones. - const std::vector tok_logits( - logits.begin() + (j + 0) * n_vocab, - logits.begin() + (j + 1) * n_vocab); + for (size_t j = 0; j < query_size - 1; j++) { + std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float)); - const float prob = softmax(tok_logits)[query_embd[ j + 1]]; + const float prob = softmax(tok_logits)[query_embd[j + 1]]; hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob); hs_data[task_idx].ending_logprob_count[ending_idx]++; @@ -267,9 +315,9 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) { } // Find the ending with maximum logprob - size_t ending_logprob_max_idx = -1; - double ending_logprob_max_val = -INFINITY; - for (size_t j=0; j < 4; j++) { + size_t ending_logprob_max_idx = 0; + double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0]; + for (size_t j = 1; j < 4; j++) { if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) { ending_logprob_max_idx = j; ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];