diff --git a/common/common.cpp b/common/common.cpp index 2a83b379e..88a962ae3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -417,6 +417,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.antiprompt.push_back(argv[i]); } else if (arg == "--perplexity") { params.perplexity = true; + } else if (arg == "--ppl-stride") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_stride = std::stoi(argv[i]); + } else if (arg == "--ppl-output-type") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_output_type = std::stoi(argv[i]); } else if (arg == "--hellaswag") { params.hellaswag = true; } else if (arg == "--hellaswag-tasks") { diff --git a/common/common.h b/common/common.h index 18fd951ea..d68a8ef88 100644 --- a/common/common.h +++ b/common/common.h @@ -64,6 +64,10 @@ struct gpt_params { std::string lora_adapter = ""; // lora adapter path std::string lora_base = ""; // base model path for the lora adapter + int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. + int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line + // (which is more convenient to use for plotting) + // bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index f3c045aec..e89725efc 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -27,7 +27,121 @@ std::vector softmax(const std::vector& logits) { return probs; } +void perplexity_v2(llama_context * ctx, const gpt_params & params) { + + // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research + // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` + // Output: `perplexity: 13.5106 [114/114]` + // BOS tokens will be added for each chunk before eval + + if (params.ppl_stride <= 0) { + fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride); + return; + } + auto tokens = ::llama_tokenize(ctx, params.prompt, true); + + const int calc_chunk = params.n_ctx; + + fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk); + + if (int(tokens.size()) <= calc_chunk) { + fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__, + tokens.size(), params.n_ctx, params.ppl_stride); + return; + } + + const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride; + + const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); + const int n_vocab = llama_n_vocab(ctx); + const int n_batch = params.n_batch; + + int count = 0; + double nll = 0.0; + + fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); + + for (int i = 0; i < n_chunk; ++i) { + const int start = i * params.ppl_stride; + const int end = start + calc_chunk; + + const int num_batches = (calc_chunk + n_batch - 1) / n_batch; + //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches); + + std::vector logits; + + const auto t_start = std::chrono::high_resolution_clock::now(); + + for (int j = 0; j < num_batches; ++j) { + const int batch_start = start + j * n_batch; + const int batch_size = std::min(end - batch_start, n_batch); + + //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); + if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { + //fprintf(stderr, "%s : failed to eval\n", __func__); + return; + } + + // save original token and restore it after eval + const auto token_org = tokens[batch_start]; + + // add BOS token for the first batch of each chunk + if (j == 0) { + tokens[batch_start] = llama_token_bos(ctx); + } + + const auto batch_logits = llama_get_logits(ctx); + logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + + if (j == 0) { + tokens[batch_start] = token_org; + } + } + + const auto t_end = std::chrono::high_resolution_clock::now(); + + if (i == 0) { + const float t_total = std::chrono::duration(t_end - t_start).count(); + fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); + int total_seconds = (int)(t_total * n_chunk); + if (total_seconds >= 60*60) { + fprintf(stderr, "%d hours ", total_seconds / (60*60)); + total_seconds = total_seconds % (60*60); + } + fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); + } + + //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start); + for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) { + + // Calculate probability of next token, given the previous ones. + const std::vector tok_logits( + logits.begin() + (j + 0) * n_vocab, + logits.begin() + (j + 1) * n_vocab); + + const float prob = softmax(tok_logits)[tokens[start + j + 1]]; + + nll += -std::log(prob); + ++count; + } + // perplexity is e^(average negative log-likelihood) + if (params.ppl_output_type == 0) { + printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + } else { + printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count)); + } + fflush(stdout); + } + printf("\n"); +} + void perplexity(llama_context * ctx, const gpt_params & params) { + + if (params.ppl_stride > 0) { + perplexity_v2(ctx, params); + return; + } + // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` @@ -116,7 +230,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) { ++count; } // perplexity is e^(average negative log-likelihood) - printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + if (params.ppl_output_type == 0) { + printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + } else { + printf("%8d %.4lf\n", i*params.n_ctx, std::exp(nll / count)); + } fflush(stdout); } printf("\n"); @@ -369,6 +487,12 @@ int main(int argc, char ** argv) { params.perplexity = true; params.n_batch = std::min(params.n_batch, params.n_ctx); + if (params.ppl_stride > 0) { + fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n", + params.n_ctx, params.n_ctx + params.ppl_stride/2); + params.n_ctx += params.ppl_stride/2; + } + if (params.n_ctx > 2048) { fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);" "expect poor results\n", __func__, params.n_ctx);