// llama.cpp/examples/perplexity/perplexity.cpp
// 728 lines, 28 KiB, C++ (header text left over from the repository web view)

#include "common.h"
#include "llama.h"
#include "build-info.h"

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <mutex>
#include <sstream>
#include <thread>
#include <vector>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
// Result bundle returned by the perplexity routines in this file and consumed
// by write_logfile(). It is brace-initialised positionally, so the member
// order below is part of its contract.
struct results_perplexity {
// Tokenized input the evaluation ran over.
std::vector<llama_token> tokens;
// exp(mean negative log-likelihood); the early-exit paths store 0 or -1 here instead.
double ppl_value;
// Per-position logit of the token that actually occurred (indexed like `tokens`).
std::vector<float> logits;
// Per-position softmax probability of the token that actually occurred (indexed like `tokens`).
std::vector<float> probs;
};
// Per-token output of log_softmax(). It is brace-initialised positionally, so
// the member order below is part of its contract.
struct results_log_softmax {
// Log-probability of the selected token (logit minus max minus log of the exp-sum).
double log_softmax;
// Raw, unnormalised logit of the selected token.
float logit;
// Softmax probability of the selected token.
float prob;
};
static void write_logfile(
const llama_context * ctx, const gpt_params & params, const llama_model * model,
const struct results_perplexity & results
) {
if (params.logdir.empty()) {
return;
}
if (params.hellaswag) {
fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__);
return;
}
const std::string timestamp = get_sortable_timestamp();
const bool success = create_directory_with_parents(params.logdir);
if (!success) {
fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
__func__, params.logdir.c_str());
return;
}
const std::string logfile_path = params.logdir + timestamp + ".yml";
FILE * logfile = fopen(logfile_path.c_str(), "w");
if (logfile == NULL) {
fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
return;
}
fprintf(logfile, "binary: main\n");
char model_desc[128];
llama_model_desc(model, model_desc, sizeof(model_desc));
dump_non_result_info_yaml(logfile, params, ctx, timestamp, results.tokens, model_desc);
fprintf(logfile, "\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "# Perplexity Results #\n");
fprintf(logfile, "######################\n");
fprintf(logfile, "\n");
dump_vector_float_yaml(logfile, "logits", results.logits);
fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
dump_vector_float_yaml(logfile, "probs", results.probs);
llama_dump_timing_info_yaml(logfile, ctx);
fclose(logfile);
}
// Converts a vector of logits into a probability distribution.
//
// The maximum logit is subtracted before exponentiating so that expf() cannot
// overflow; this leaves the mathematical result unchanged. The normalising sum
// is accumulated in double precision.
//
//   logits - unnormalised scores; may be empty
//   returns a vector of the same length whose entries sum to ~1, or an empty
//           vector when `logits` is empty
static std::vector<float> softmax(const std::vector<float>& logits) {
    // An empty input has no maximum element; reading logits[0] would be
    // undefined behaviour, so return the (empty) distribution directly.
    if (logits.empty()) {
        return {};
    }

    std::vector<float> probs(logits.size());

    const float max_logit = *std::max_element(logits.begin(), logits.end());

    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        // Subtract the maximum logit value from the current logit value for numerical stability
        const float exp_logit = expf(logits[i] - max_logit);
        sum_exp += exp_logit;
        probs[i] = exp_logit;
    }

    for (float & p : probs) {
        p /= sum_exp;
    }
    return probs;
}
// Computes the log-softmax, raw logit and softmax probability of one token.
//
//   n_vocab - number of entries readable through `logits`
//   logits  - unnormalised scores, one per vocabulary entry
//   tok     - index of the token of interest within `logits`
//   returns {log-probability of tok, logits[tok], probability of tok}
//
// The largest logit is subtracted before exponentiating to keep expf() in
// range; the exp-sum is accumulated in double precision.
static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
    const float * const logits_end = logits + n_vocab;

    // Largest logit, seeded with the first entry.
    float peak = logits[0];
    for (const float * p = logits + 1; p < logits_end; ++p) {
        peak = std::max(peak, *p);
    }

    // Normalising constant of the shifted distribution.
    double denom = 0.0;
    for (const float * p = logits; p < logits_end; ++p) {
        denom += expf(*p - peak);
    }

    const float tok_logit = logits[tok];
    return {tok_logit - peak - log(denom), tok_logit, expf(tok_logit - peak) / (float) denom};
}
static void process_logits(
int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
double & nll, double & nll2, float * logit_history, float * prob_history
) {
std::mutex mutex;
int counter = 0;
auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
double local_nll = 0, local_nll2 = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
int i = counter++;
if (i >= n_token) {
nll += local_nll; nll2 += local_nll2;
break;
}
lock.unlock();
const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
const double v = -results.log_softmax;
local_nll += v;
local_nll2 += v*v;
logit_history[i] = results.logit;
prob_history[i] = results.prob;
}
};
for (auto & w : workers) w = std::thread(compute);
compute();
for (auto & w : workers) w.join();
}
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
llm : add Falcon support (#2717) * llama : refactor GGUF constants into static maps * llama : check if model architecture is known * llama : refactor llama_model_load_internal() * gguf : add KV constant maps * llm : read arch-specific KVs * convert : add dummy scores + types * falcon : load tensor data (CPU only) * llama : fix loading progress bar * llama : add arch member to llama_model * falcon : CPU inference working * falcon : support non-40B models * falcon : minor * llama : minor updates ggml-ci * convert-falcon-hf-to-gguf.py : fix special token mapping * llama.cpp : llama default UNK token = id 0 * llama.cpp : fix bpe tokenizer * llama.cpp : fix the fix of bpe tokenizer * ggml : pass eps to ggml_norm * metal : implement RoPE (mode = 2) + avoid ggml_repeat * ggml : ggml_repeat always creates new tensor * falcon : copy-paste self-attention from LLaMA * metal : print extra compute pipeline info * falcon : minor changes (still chasing the Metal problem) * llama.cpp : fix linefeed token * metal : fix GELU kernel numerical stability by using precise::tanh * metal : temporary workaround for the concurrency optimization bug * falcon : add CUDA offloading (#2739) * llama : better model naming and size reporting * llama : prep new tokenizer support * llama : advanced BPE tokenizer based on ggllm.cpp imlpementation * llama : remove oboslete comment ggml-ci * common : remove obsolete BPE API + disable test-tokenizer-1 * llama : revert BPE special-case in llama_byte_to_token() * cuda : add TODOs for RoPE NeoX implementation * llama : default special tokens based on vocab type * perplexity : add log for start of tokenization --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: slaren <slarengh@gmail.com>
2023-08-23 22:08:04 +02:00
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
if (int(tokens.size()) < 2*params.n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
params.n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
std::vector<float> logit_history;
std::vector<float> prob_history;
logit_history.resize(tokens.size());
prob_history.resize(tokens.size());
if (params.ppl_stride <= 0) {
fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
return {tokens, -1, logit_history, prob_history};
}
const int calc_chunk = params.n_ctx;
fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
if (int(tokens.size()) <= calc_chunk) {
fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
tokens.size(), params.n_ctx, params.ppl_stride);
return {tokens, -1, logit_history, prob_history};
}
const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(ctx);
const int n_batch = params.n_batch;
int count = 0;
double nll = 0.0;
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
for (int i = 0; i < n_chunk; ++i) {
const int start = i * params.ppl_stride;
const int end = start + calc_chunk;
const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
//fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
std::vector<float> logits;
const auto t_start = std::chrono::high_resolution_clock::now();
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
//fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
// save original token and restore it after eval
const auto token_org = tokens[batch_start];
// add BOS token for the first batch of each chunk
llm : add Falcon support (#2717) * llama : refactor GGUF constants into static maps * llama : check if model architecture is known * llama : refactor llama_model_load_internal() * gguf : add KV constant maps * llm : read arch-specific KVs * convert : add dummy scores + types * falcon : load tensor data (CPU only) * llama : fix loading progress bar * llama : add arch member to llama_model * falcon : CPU inference working * falcon : support non-40B models * falcon : minor * llama : minor updates ggml-ci * convert-falcon-hf-to-gguf.py : fix special token mapping * llama.cpp : llama default UNK token = id 0 * llama.cpp : fix bpe tokenizer * llama.cpp : fix the fix of bpe tokenizer * ggml : pass eps to ggml_norm * metal : implement RoPE (mode = 2) + avoid ggml_repeat * ggml : ggml_repeat always creates new tensor * falcon : copy-paste self-attention from LLaMA * metal : print extra compute pipeline info * falcon : minor changes (still chasing the Metal problem) * llama.cpp : fix linefeed token * metal : fix GELU kernel numerical stability by using precise::tanh * metal : temporary workaround for the concurrency optimization bug * falcon : add CUDA offloading (#2739) * llama : better model naming and size reporting * llama : prep new tokenizer support * llama : advanced BPE tokenizer based on ggllm.cpp imlpementation * llama : remove oboslete comment ggml-ci * common : remove obsolete BPE API + disable test-tokenizer-1 * llama : revert BPE special-case in llama_byte_to_token() * cuda : add TODOs for RoPE NeoX implementation * llama : default special tokens based on vocab type * perplexity : add log for start of tokenization --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: slaren <slarengh@gmail.com>
2023-08-23 22:08:04 +02:00
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(ctx);
}
const auto batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
if (j == 0) {
tokens[batch_start] = token_org;
}
}
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
if (total_seconds >= 60*60) {
fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
}
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
//fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
// Calculate probability of next token, given the previous ones.
const std::vector<float> tok_logits(
logits.begin() + (j + 0) * n_vocab,
logits.begin() + (j + 1) * n_vocab);
const float prob = softmax(tok_logits)[tokens[start + j + 1]];
logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
prob_history[start + j + 1] = prob;
nll += -std::log(prob);
++count;
}
// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
} else {
printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
}
fflush(stdout);
}
printf("\n");
return {tokens, std::exp(nll / count), logit_history, prob_history};
}
static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
if (params.ppl_stride > 0) {
return perplexity_v2(ctx, params);
}
// Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
// Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
llm : add Falcon support (#2717) * llama : refactor GGUF constants into static maps * llama : check if model architecture is known * llama : refactor llama_model_load_internal() * gguf : add KV constant maps * llm : read arch-specific KVs * convert : add dummy scores + types * falcon : load tensor data (CPU only) * llama : fix loading progress bar * llama : add arch member to llama_model * falcon : CPU inference working * falcon : support non-40B models * falcon : minor * llama : minor updates ggml-ci * convert-falcon-hf-to-gguf.py : fix special token mapping * llama.cpp : llama default UNK token = id 0 * llama.cpp : fix bpe tokenizer * llama.cpp : fix the fix of bpe tokenizer * ggml : pass eps to ggml_norm * metal : implement RoPE (mode = 2) + avoid ggml_repeat * ggml : ggml_repeat always creates new tensor * falcon : copy-paste self-attention from LLaMA * metal : print extra compute pipeline info * falcon : minor changes (still chasing the Metal problem) * llama.cpp : fix linefeed token * metal : fix GELU kernel numerical stability by using precise::tanh * metal : temporary workaround for the concurrency optimization bug * falcon : add CUDA offloading (#2739) * llama : better model naming and size reporting * llama : prep new tokenizer support * llama : advanced BPE tokenizer based on ggllm.cpp imlpementation * llama : remove oboslete comment ggml-ci * common : remove obsolete BPE API + disable test-tokenizer-1 * llama : revert BPE special-case in llama_byte_to_token() * cuda : add TODOs for RoPE NeoX implementation * llama : default special tokens based on vocab type * perplexity : add log for start of tokenization --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: slaren <slarengh@gmail.com>
2023-08-23 22:08:04 +02:00
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
auto tim1 = std::chrono::high_resolution_clock::now();
llm : add Falcon support (#2717) * llama : refactor GGUF constants into static maps * llama : check if model architecture is known * llama : refactor llama_model_load_internal() * gguf : add KV constant maps * llm : read arch-specific KVs * convert : add dummy scores + types * falcon : load tensor data (CPU only) * llama : fix loading progress bar * llama : add arch member to llama_model * falcon : CPU inference working * falcon : support non-40B models * falcon : minor * llama : minor updates ggml-ci * convert-falcon-hf-to-gguf.py : fix special token mapping * llama.cpp : llama default UNK token = id 0 * llama.cpp : fix bpe tokenizer * llama.cpp : fix the fix of bpe tokenizer * ggml : pass eps to ggml_norm * metal : implement RoPE (mode = 2) + avoid ggml_repeat * ggml : ggml_repeat always creates new tensor * falcon : copy-paste self-attention from LLaMA * metal : print extra compute pipeline info * falcon : minor changes (still chasing the Metal problem) * llama.cpp : fix linefeed token * metal : fix GELU kernel numerical stability by using precise::tanh * metal : temporary workaround for the concurrency optimization bug * falcon : add CUDA offloading (#2739) * llama : better model naming and size reporting * llama : prep new tokenizer support * llama : advanced BPE tokenizer based on ggllm.cpp imlpementation * llama : remove oboslete comment ggml-ci * common : remove obsolete BPE API + disable test-tokenizer-1 * llama : revert BPE special-case in llama_byte_to_token() * cuda : add TODOs for RoPE NeoX implementation * llama : default special tokens based on vocab type * perplexity : add log for start of tokenization --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: slaren <slarengh@gmail.com>
2023-08-23 22:08:04 +02:00
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
if (int(tokens.size()) < 2*params.n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
params.n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
std::vector<float> logit_history;
logit_history.resize(tokens.size());
std::vector<float> prob_history;
prob_history.resize(tokens.size());
const int n_chunk_max = tokens.size() / params.n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(ctx);
const int n_batch = params.n_batch;
int count = 0;
double nll = 0.0;
double nll2 = 0.0;
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
for (int i = 0; i < n_chunk; ++i) {
const int start = i * params.n_ctx;
const int end = start + params.n_ctx;
const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
const auto t_start = std::chrono::high_resolution_clock::now();
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
// save original token and restore it after eval
const auto token_org = tokens[batch_start];
// add BOS token for the first batch of each chunk
llm : add Falcon support (#2717) * llama : refactor GGUF constants into static maps * llama : check if model architecture is known * llama : refactor llama_model_load_internal() * gguf : add KV constant maps * llm : read arch-specific KVs * convert : add dummy scores + types * falcon : load tensor data (CPU only) * llama : fix loading progress bar * llama : add arch member to llama_model * falcon : CPU inference working * falcon : support non-40B models * falcon : minor * llama : minor updates ggml-ci * convert-falcon-hf-to-gguf.py : fix special token mapping * llama.cpp : llama default UNK token = id 0 * llama.cpp : fix bpe tokenizer * llama.cpp : fix the fix of bpe tokenizer * ggml : pass eps to ggml_norm * metal : implement RoPE (mode = 2) + avoid ggml_repeat * ggml : ggml_repeat always creates new tensor * falcon : copy-paste self-attention from LLaMA * metal : print extra compute pipeline info * falcon : minor changes (still chasing the Metal problem) * llama.cpp : fix linefeed token * metal : fix GELU kernel numerical stability by using precise::tanh * metal : temporary workaround for the concurrency optimization bug * falcon : add CUDA offloading (#2739) * llama : better model naming and size reporting * llama : prep new tokenizer support * llama : advanced BPE tokenizer based on ggllm.cpp imlpementation * llama : remove oboslete comment ggml-ci * common : remove obsolete BPE API + disable test-tokenizer-1 * llama : revert BPE special-case in llama_byte_to_token() * cuda : add TODOs for RoPE NeoX implementation * llama : default special tokens based on vocab type * perplexity : add log for start of tokenization --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: slaren <slarengh@gmail.com>
2023-08-23 22:08:04 +02:00
if (add_bos && j == 0) {
gguf : new file format with flexible meta data (beta) (#2398) * gguf : first API pass * gguf : read header + meta data * gguf : read tensor info * gguf : initial model loading - not tested * gguf : add gguf_get_tensor_name() * gguf : do not support passing existing ggml_context to gguf_init * gguf : simplify gguf_get_val * gguf : gguf.c is now part of ggml.c * gguf : read / write sample models * gguf : add comments * refactor : reduce code duplication and better API (#2415) * gguf : expose the gguf_type enum through the API for now * gguf : add array support * gguf.py : some code style changes * convert.py : start a new simplified implementation by removing old stuff * convert.py : remove GGML vocab + other obsolete stuff * GGUF : write tensor (#2426) * WIP: Write tensor * GGUF : Support writing tensors in Python * refactor : rm unused import and upd todos * fix : fix errors upd writing example * rm example.gguf * gitignore *.gguf * undo formatting * gguf : add gguf_find_key (#2438) * gguf.cpp : find key example * ggml.h : add gguf_find_key * ggml.c : add gguf_find_key * gguf : fix writing tensors * gguf : do not hardcode tensor names to read * gguf : write sample tensors to read * gguf : add tokenization constants * quick and dirty conversion example * gguf : fix writing gguf arrays * gguf : write tensors one by one and code reuse * gguf : fix writing gguf arrays * gguf : write tensors one by one * gguf : write tensors one by one * gguf : write tokenizer data * gguf : upd gguf conversion script * Update convert-llama-h5-to-gguf.py * gguf : handle already encoded string * ggml.h : get array str and f32 * ggml.c : get arr str and f32 * gguf.py : support any type * Update convert-llama-h5-to-gguf.py * gguf : fix set is not subscriptable * gguf : update convert-llama-h5-to-gguf.py * constants.py : add layer norm eps * gguf.py : add layer norm eps and merges * ggml.h : increase GGML_MAX_NAME to 64 * ggml.c : add gguf_get_arr_n * Update convert-llama-h5-to-gguf.py * add 
gptneox gguf example * Makefile : add gptneox gguf example * Update convert-llama-h5-to-gguf.py * add gptneox gguf example * Update convert-llama-h5-to-gguf.py * Update convert-gptneox-h5-to-gguf.py * Update convert-gptneox-h5-to-gguf.py * Update convert-llama-h5-to-gguf.py * gguf : support custom alignment value * gguf : fix typo in function call * gguf : mmap tensor data example * fix : update convert-llama-h5-to-gguf.py * Update convert-llama-h5-to-gguf.py * convert-gptneox-h5-to-gguf.py : Special tokens * gptneox-main.cpp : special tokens * Update gptneox-main.cpp * constants.py : special tokens * gguf.py : accumulate kv and tensor info data + special tokens * convert-gptneox-h5-to-gguf.py : accumulate kv and ti + special tokens * gguf : gguf counterpart of llama-util.h * gguf-util.h : update note * convert-llama-h5-to-gguf.py : accumulate kv / ti + special tokens * convert-llama-h5-to-gguf.py : special tokens * Delete gptneox-common.cpp * Delete gptneox-common.h * convert-gptneox-h5-to-gguf.py : gpt2bpe tokenizer * gptneox-main.cpp : gpt2 bpe tokenizer * gpt2 bpe tokenizer (handles merges and unicode) * Makefile : remove gptneox-common * gguf.py : bytesarray for gpt2bpe tokenizer * cmpnct_gpt2bpe.hpp : comments * gguf.py : use custom alignment if present * gguf : minor stuff * Update gptneox-main.cpp * map tensor names * convert-gptneox-h5-to-gguf.py : map tensor names * convert-llama-h5-to-gguf.py : map tensor names * gptneox-main.cpp : map tensor names * gguf : start implementing libllama in GGUF (WIP) * gguf : start implementing libllama in GGUF (WIP) * rm binary commited by mistake * upd .gitignore * gguf : calculate n_mult * gguf : inference with 7B model working (WIP) * gguf : rm deprecated function * gguf : start implementing gguf_file_saver (WIP) * gguf : start implementing gguf_file_saver (WIP) * gguf : start implementing gguf_file_saver (WIP) * gguf : add gguf_get_kv_type * gguf : add gguf_get_kv_type * gguf : write metadata in gguf_file_saver (WIP) 
* gguf : write metadata in gguf_file_saver (WIP) * gguf : write metadata in gguf_file_saver * gguf : rm references to old file formats * gguf : shorter name for member variable * gguf : rm redundant method * gguf : get rid of n_mult, read n_ff from file * Update gguf_tensor_map.py * Update gptneox-main.cpp * gguf : rm references to old file magics * gguf : start implementing quantization (WIP) * gguf : start implementing quantization (WIP) * gguf : start implementing quantization (WIP) * gguf : start implementing quantization (WIP) * gguf : start implementing quantization (WIP) * gguf : start implementing quantization (WIP) * gguf : quantization is working * gguf : roper closing of file * gguf.py : no need to convert tensors twice * convert-gptneox-h5-to-gguf.py : no need to convert tensors twice * convert-llama-h5-to-gguf.py : no need to convert tensors twice * convert-gptneox-h5-to-gguf.py : simplify nbytes * convert-llama-h5-to-gguf.py : simplify nbytes * gptneox-main.cpp : n_layer --> n_block * constants.py : n_layer --> n_block * gguf.py : n_layer --> n_block * convert-gptneox-h5-to-gguf.py : n_layer --> n_block * convert-llama-h5-to-gguf.py : n_layer --> n_block * gptneox-main.cpp : n_layer --> n_block * Update gguf_tensor_map.py * convert-gptneox-h5-to-gguf.py : load model in parts to save memory * convert-llama-h5-to-gguf.py : load model in parts to save memory * convert : write more metadata for LLaMA * convert : rm quantization version * convert-gptneox-h5-to-gguf.py : add file_type key * gptneox-main.cpp : add file_type key * fix conflicts * gguf : add todos and comments * convert-gptneox-h5-to-gguf.py : tensor name map changes * Create gguf_namemap.py : tensor name map changes * Delete gguf_tensor_map.py * gptneox-main.cpp : tensor name map changes * convert-llama-h5-to-gguf.py : fixes * gguf.py : dont add empty strings * simple : minor style changes * gguf : use UNIX line ending * Create convert-llama-7b-pth-to-gguf.py * llama : sync gguf-llama.cpp 
with latest llama.cpp (#2608) * llama : sync gguf-llama.cpp with latest llama.cpp * minor : indentation + assert * llama : refactor gguf_buffer and gguf_ctx_buffer * llama : minor * gitignore : add gptneox-main * llama : tokenizer fixes (#2549) * Merge tokenizer fixes into the gguf branch. * Add test vocabularies * convert : update convert-new.py with tokenizer fixes (#2614) * Merge tokenizer fixes into the gguf branch. * Add test vocabularies * Adapt convert-new.py (and fix a clang-cl compiler error on windows) * llama : sync gguf-llama with llama (#2613) * llama : sync gguf-llama with llama * tests : fix build + warnings (test-tokenizer-1 still fails) * tests : fix wstring_convert * convert : fix layer names * llama : sync gguf-llama.cpp * convert : update HF converter to new tokenizer voodoo magics * llama : update tokenizer style * convert-llama-h5-to-gguf.py : add token types * constants.py : add token types * gguf.py : add token types * convert-llama-7b-pth-to-gguf.py : add token types * gguf-llama.cpp : fix n_head_kv * convert-llama-h5-to-gguf.py : add 70b gqa support * gguf.py : add tensor data layout * convert-llama-h5-to-gguf.py : add tensor data layout * convert-llama-7b-pth-to-gguf.py : add tensor data layout * gptneox-main.cpp : add tensor data layout * convert-llama-h5-to-gguf.py : clarify the reverse permute * llama : refactor model loading code (#2620) * llama : style formatting + remove helper methods * llama : fix quantization using gguf tool * llama : simplify gguf_file_saver * llama : fix method names * llama : simplify write_header() * llama : no need to pass full file loader to the file saver just gguf_ctx * llama : gguf_file_saver write I32 * llama : refactor tensor names (#2622) * gguf: update tensor names searched in quantization * gguf : define tensor names as constants * gguf : initial write API (not tested yet) * gguf : write to file API (not tested) * gguf : initial write API ready + example * gguf : fix header write * gguf : fixes + 
simplify example + add ggml_nbytes_pad() * gguf : minor * llama : replace gguf_file_saver with new gguf write API * gguf : streaming support when writing files * gguf : remove oboslete write methods * gguf : remove obosolete gguf_get_arr_xxx API * llama : simplify gguf_file_loader * llama : move hparams and vocab from gguf_file_loader to llama_model_loader * llama : merge gguf-util.h in llama.cpp * llama : reorder definitions in .cpp to match .h * llama : minor simplifications * llama : refactor llama_model_loader (WIP) wip : remove ggml_ctx from llama_model_loader wip : merge gguf_file_loader in llama_model_loader * llama : fix shape prints * llama : fix Windows build + fix norm_rms_eps key * llama : throw error on missing KV paris in model meta data * llama : improve printing + log meta data * llama : switch print order of meta data --------- Co-authored-by: M. Yusuf Sarıgöz <yusufsarigoz@gmail.com> * gguf : deduplicate (#2629) * gguf : better type names * dedup : CPU + Metal is working * ggml : fix warnings about unused results * llama.cpp : fix line feed and compiler warning * llama : fix strncpy warning + note token_to_str does not write null * llama : restore the original load/save session implementation Will migrate this to GGUF in the future * convert-llama-h5-to-gguf.py : support alt ctx param name * ggml : assert when using ggml_mul with non-F32 src1 * examples : dedup simple --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> * gguf.py : merge all files in gguf.py * convert-new.py : pick #2427 for HF 70B support * examples/gguf : no need to keep q option for quantization any more * llama.cpp : print actual model size * llama.cpp : use ggml_elements() * convert-new.py : output gguf (#2635) * convert-new.py : output gguf (WIP) * convert-new.py : add gguf key-value pairs * llama : add hparams.ctx_train + no longer print ftype * convert-new.py : minor fixes * convert-new.py : vocab-only option should work now * llama : fix tokenizer 
to use llama_char_to_byte * tests : add new ggml-vocab-llama.gguf * convert-new.py : tensor name mapping * convert-new.py : add map for skipping tensor serialization * convert-new.py : convert script now works * gguf.py : pick some of the refactoring from #2644 * convert-new.py : minor fixes * convert.py : update to support GGUF output * Revert "ci : disable CI temporary to not waste energy" This reverts commit 7e82d25f40386540c2c15226300ad998ecd871ea. * convert.py : n_head_kv optional and .gguf file extension * convert.py : better always have n_head_kv and default it to n_head * llama : sync with recent PRs on master * editorconfig : ignore models folder ggml-ci * ci : update ".bin" to ".gguf" extension ggml-ci * llama : fix llama_model_loader memory leak * gptneox : move as a WIP example * llama : fix lambda capture ggml-ci * ggml : fix bug in gguf_set_kv ggml-ci * common.h : .bin --> .gguf * quantize-stats.cpp : .bin --> .gguf * convert.py : fix HF tensor permuting / unpacking ggml-ci * llama.cpp : typo * llama : throw error if gguf fails to init from file ggml-ci * llama : fix tensor name grepping during quantization ggml-ci * gguf.py : write tensors in a single pass (#2644) * gguf : single pass for writing tensors + refactoring writer * gguf : single pass for writing tensors + refactoring writer * gguf : single pass for writing tensors + refactoring writer * gguf : style fixes in simple conversion script * gguf : refactor gptneox conversion script * gguf : rename h5 to hf (for HuggingFace) * gguf : refactor pth to gguf conversion script * gguf : rm file_type key and method * gguf.py : fix vertical alignment * gguf.py : indentation --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * convert-gptneox-hf-to-gguf.py : fixes * gguf.py : gptneox mapping * convert-llama-hf-to-gguf.py : fixes * convert-llama-7b-pth-to-gguf.py : fixes * ggml.h : reverse GGUF_MAGIC * gguf.py : reverse GGUF_MAGIC * test-tokenizer-0.cpp : fix warning * llama.cpp : print kv 
general.name * llama.cpp : get special token kv and linefeed token id * llama : print number of tensors per type + print arch + style * tests : update vocab file with new magic * editorconfig : fix whitespaces * llama : re-order functions * llama : remove C++ API + reorganize common source in /common dir * llama : minor API updates * llama : avoid hardcoded special tokens * llama : fix MPI build ggml-ci * llama : introduce enum llama_vocab_type + remove hardcoded string constants * convert-falcon-hf-to-gguf.py : falcon HF --> gguf conversion, not tested * falcon-main.cpp : falcon inference example * convert-falcon-hf-to-gguf.py : remove extra kv * convert-gptneox-hf-to-gguf.py : remove extra kv * convert-llama-7b-pth-to-gguf.py : remove extra kv * convert-llama-hf-to-gguf.py : remove extra kv * gguf.py : fix for falcon 40b * falcon-main.cpp : fix for falcon 40b * convert-falcon-hf-to-gguf.py : update ref * convert-falcon-hf-to-gguf.py : add tensor data layout * cmpnct_gpt2bpe.hpp : fixes * falcon-main.cpp : fixes * gptneox-main.cpp : fixes * cmpnct_gpt2bpe.hpp : remove non-general stuff * Update examples/server/README.md Co-authored-by: slaren <slarengh@gmail.com> * cmpnct_gpt2bpe.hpp : cleanup * convert-llama-hf-to-gguf.py : special tokens * convert-llama-7b-pth-to-gguf.py : special tokens * convert-permute-debug.py : permute debug print * convert-permute-debug-master.py : permute debug for master * convert-permute-debug.py : change permute type of attn_q * convert.py : 70b model working (change attn_q permute) * Delete convert-permute-debug-master.py * Delete convert-permute-debug.py * convert-llama-hf-to-gguf.py : fix attn_q permute * gguf.py : fix rope scale kv * convert-llama-hf-to-gguf.py : rope scale and added tokens * convert-llama-7b-pth-to-gguf.py : rope scale and added tokens * llama.cpp : use rope scale kv * convert-llama-7b-pth-to-gguf.py : rope scale fix * convert-llama-hf-to-gguf.py : rope scale fix * py : fix whitespace * gguf : add Python script to 
convert GGMLv3 LLaMA models to GGUF (#2682) * First pass at converting GGMLv3 LLaMA models to GGUF * Cleanups, better output during conversion * Fix vocab space conversion logic * More vocab conversion fixes * Add description to converted GGUF files * Improve help text, expand warning * Allow specifying name and description for output GGUF * Allow overriding vocab and hyperparams from original model metadata * Use correct params override var name * Fix wrong type size for Q8_K Better handling of original style metadata * Set default value for gguf add_tensor raw_shape KW arg * llama : improve token type support (#2668) * Merge tokenizer fixes into the gguf branch. * Add test vocabularies * Adapt convert-new.py (and fix a clang-cl compiler error on windows) * Improved tokenizer test But does it work on MacOS? * Improve token type support - Added @klosax code to convert.py - Improved token type support in vocabulary * Exclude platform dependent tests * More sentencepiece compatibility by eliminating magic numbers * Restored accidentally removed comment * llama : add API for token type ggml-ci * tests : use new tokenizer type API (#2692) * Merge tokenizer fixes into the gguf branch. * Add test vocabularies * Adapt convert-new.py (and fix a clang-cl compiler error on windows) * Improved tokenizer test But does it work on MacOS? * Improve token type support - Added @klosax code to convert.py - Improved token type support in vocabulary * Exclude platform dependent tests * More sentencepiece compatibility by eliminating magic numbers * Restored accidentally removed comment * Improve commentary * Use token type API in test-tokenizer-1.cpp * py : cosmetics * readme : add notice about new file format ggml-ci --------- Co-authored-by: M. 
Yusuf Sarıgöz <yusufsarigoz@gmail.com> Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: goerch <jhr.walter@t-online.de> Co-authored-by: slaren <slarengh@gmail.com> Co-authored-by: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>
2023-08-21 22:07:43 +02:00
tokens[batch_start] = llama_token_bos(ctx);
}
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
const auto batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
if (total_seconds >= 60*60) {
fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
}
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
// We get the logits for all the tokens in the context window (params.n_ctx)
// from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
// calculate the perplexity over the last half of the window (so the model always has
// some context to predict the token).
//
// We rely on the fact that attention in the forward pass only looks at previous
// tokens here, so the logits returned for each token are an accurate representation
// of what the model would have predicted at that point.
//
// Example, we have a context window of 512, we will compute perplexity for each of the
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
const int first = params.n_ctx/2;
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += params.n_ctx - first - 1;
// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
} else {
double av = nll/count;
double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1));
printf("%8d %.4lf %4lf %4lf\n", i*params.n_ctx, std::exp(nll / count), av, av2);
}
fflush(stdout);
}
printf("\n");
nll2 /= count;
nll /= count;
const double ppl = exp(nll);
nll2 -= nll * nll;
if (nll2 > 0) {
nll2 = sqrt(nll2/(count-1));
printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
} else {
printf("Unexpected negative standard deviation of log(prob)\n");
}
return {tokens, ppl, logit_history, prob_history};
}
// Feeds `tokens` to llama_eval in slices of at most n_batch tokens, starting at position
// n_past, and returns the logits of all slices concatenated in evaluation order
// (n_vocab floats are copied per evaluated token).
// An empty vector is returned as soon as one llama_eval call reports failure.
static std::vector<float> hellaswag_evaluate_tokens(
    llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab, int n_thread
) {
    std::vector<float> all_logits;
    all_logits.reserve(tokens.size() * n_vocab);

    // number of slices needed to cover every token (rounded up)
    const size_t slice_count = (tokens.size() + n_batch - 1)/n_batch;

    for (size_t slice = 0; slice < slice_count; ++slice) {
        const size_t offset    = slice * n_batch;
        const size_t remaining = tokens.size() - offset;
        // the last slice may be shorter than n_batch
        const size_t n_eval    = std::min(remaining, size_t(n_batch));

        if (llama_eval(ctx, tokens.data() + offset, n_eval, n_past, n_thread)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return {};
        }

        // copy the logits of this slice before the next llama_eval call
        const auto slice_logits = llama_get_logits(ctx);
        all_logits.insert(all_logits.end(), slice_logits, slice_logits + n_eval * n_vocab);

        n_past += n_eval;
    }

    return all_logits;
}
static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
// Calculates hellaswag score (acc_norm) from prompt
//
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
// All used data fields are preprocessed as in https://github.com/EleutherAI/lm-evaluation-harness/blob/df3da98c5405deafd519c2ddca52bb7c3fe36bef/lm_eval/tasks/hellaswag.py#L62-L68
//
// All 10042 tasks should be extracted to keep the results standardized like other implementations.
//
// Datafile layout:
// ['??'] denotes json fields
// 6 lines per task:
// ['activity_label'] + ": " +['ctx'] - The first part of the query, the context
// ['label'] - The index the best common sense ending aka gold ending
// ['endings'][0] - Endings added to the first part of the query
// ['endings'][1]
// ['endings'][2]
// ['endings'][3]
std::vector<std::string> prompt_lines;
std::istringstream strstream(params.prompt);
std::string line;
while (std::getline(strstream,line,'\n')) {
prompt_lines.push_back(line);
}
if( prompt_lines.size() % 6 != 0) {
fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
return;
}
size_t hs_task_count = prompt_lines.size()/6;
fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
llm : add Falcon support (#2717) * llama : refactor GGUF constants into static maps * llama : check if model architecture is known * llama : refactor llama_model_load_internal() * gguf : add KV constant maps * llm : read arch-specific KVs * convert : add dummy scores + types * falcon : load tensor data (CPU only) * llama : fix loading progress bar * llama : add arch member to llama_model * falcon : CPU inference working * falcon : support non-40B models * falcon : minor * llama : minor updates ggml-ci * convert-falcon-hf-to-gguf.py : fix special token mapping * llama.cpp : llama default UNK token = id 0 * llama.cpp : fix bpe tokenizer * llama.cpp : fix the fix of bpe tokenizer * ggml : pass eps to ggml_norm * metal : implement RoPE (mode = 2) + avoid ggml_repeat * ggml : ggml_repeat always creates new tensor * falcon : copy-paste self-attention from LLaMA * metal : print extra compute pipeline info * falcon : minor changes (still chasing the Metal problem) * llama.cpp : fix linefeed token * metal : fix GELU kernel numerical stability by using precise::tanh * metal : temporary workaround for the concurrency optimization bug * falcon : add CUDA offloading (#2739) * llama : better model naming and size reporting * llama : prep new tokenizer support * llama : advanced BPE tokenizer based on ggllm.cpp imlpementation * llama : remove oboslete comment ggml-ci * common : remove obsolete BPE API + disable test-tokenizer-1 * llama : revert BPE special-case in llama_byte_to_token() * cuda : add TODOs for RoPE NeoX implementation * llama : default special tokens based on vocab type * perplexity : add log for start of tokenization --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: slaren <slarengh@gmail.com>
2023-08-23 22:08:04 +02:00
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
fprintf(stderr, "================================= is_spm = %d\n", is_spm);
llm : add Falcon support (#2717) * llama : refactor GGUF constants into static maps * llama : check if model architecture is known * llama : refactor llama_model_load_internal() * gguf : add KV constant maps * llm : read arch-specific KVs * convert : add dummy scores + types * falcon : load tensor data (CPU only) * llama : fix loading progress bar * llama : add arch member to llama_model * falcon : CPU inference working * falcon : support non-40B models * falcon : minor * llama : minor updates ggml-ci * convert-falcon-hf-to-gguf.py : fix special token mapping * llama.cpp : llama default UNK token = id 0 * llama.cpp : fix bpe tokenizer * llama.cpp : fix the fix of bpe tokenizer * ggml : pass eps to ggml_norm * metal : implement RoPE (mode = 2) + avoid ggml_repeat * ggml : ggml_repeat always creates new tensor * falcon : copy-paste self-attention from LLaMA * metal : print extra compute pipeline info * falcon : minor changes (still chasing the Metal problem) * llama.cpp : fix linefeed token * metal : fix GELU kernel numerical stability by using precise::tanh * metal : temporary workaround for the concurrency optimization bug * falcon : add CUDA offloading (#2739) * llama : better model naming and size reporting * llama : prep new tokenizer support * llama : advanced BPE tokenizer based on ggllm.cpp imlpementation * llama : remove oboslete comment ggml-ci * common : remove obsolete BPE API + disable test-tokenizer-1 * llama : revert BPE special-case in llama_byte_to_token() * cuda : add TODOs for RoPE NeoX implementation * llama : default special tokens based on vocab type * perplexity : add log for start of tokenization --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: slaren <slarengh@gmail.com>
2023-08-23 22:08:04 +02:00
// This is needed as usual for LLaMA models
llm : add Falcon support (#2717) * llama : refactor GGUF constants into static maps * llama : check if model architecture is known * llama : refactor llama_model_load_internal() * gguf : add KV constant maps * llm : read arch-specific KVs * convert : add dummy scores + types * falcon : load tensor data (CPU only) * llama : fix loading progress bar * llama : add arch member to llama_model * falcon : CPU inference working * falcon : support non-40B models * falcon : minor * llama : minor updates ggml-ci * convert-falcon-hf-to-gguf.py : fix special token mapping * llama.cpp : llama default UNK token = id 0 * llama.cpp : fix bpe tokenizer * llama.cpp : fix the fix of bpe tokenizer * ggml : pass eps to ggml_norm * metal : implement RoPE (mode = 2) + avoid ggml_repeat * ggml : ggml_repeat always creates new tensor * falcon : copy-paste self-attention from LLaMA * metal : print extra compute pipeline info * falcon : minor changes (still chasing the Metal problem) * llama.cpp : fix linefeed token * metal : fix GELU kernel numerical stability by using precise::tanh * metal : temporary workaround for the concurrency optimization bug * falcon : add CUDA offloading (#2739) * llama : better model naming and size reporting * llama : prep new tokenizer support * llama : advanced BPE tokenizer based on ggllm.cpp imlpementation * llama : remove oboslete comment ggml-ci * common : remove obsolete BPE API + disable test-tokenizer-1 * llama : revert BPE special-case in llama_byte_to_token() * cuda : add TODOs for RoPE NeoX implementation * llama : default special tokens based on vocab type * perplexity : add log for start of tokenization --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: slaren <slarengh@gmail.com>
2023-08-23 22:08:04 +02:00
const bool add_bos = is_spm;
// Number of tasks to use when computing the score
if ( params.hellaswag_tasks < hs_task_count ) {
hs_task_count = params.hellaswag_tasks;
}
// The tasks should be randomized so the score stabilizes quickly.
bool randomize_tasks = true;
// The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now
std::mt19937 rng(1);
// Dataholder for hellaswag tasks
struct hs_data_t {
std::string context;
size_t gold_ending_idx;
std::string ending[4];
size_t ending_logprob_count[4];
double ending_logprob[4];
};
fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );
// Select and read data from prompt lines
hs_data_t *hs_data = new hs_data_t[hs_task_count];
for (size_t i=0; i < hs_task_count; i++) {
size_t idx = i;
// Select a random example of those left in the prompt
if (randomize_tasks) {
std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
idx = dist(rng);
}
hs_data[i].context = prompt_lines[idx*6];
hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
for (size_t j=0; j < 4; j++) {
hs_data[i].ending[j] = prompt_lines[idx*6+2+j];
}
// Delete the selected random example from the prompt
if (randomize_tasks) {
prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) );
}
}
fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
printf("\ntask\tacc_norm\n");
double acc = 0.0f;
const int n_vocab = llama_n_vocab(ctx);
std::vector<std::vector<int>> ending_tokens(4);
std::vector<float> tok_logits(n_vocab);
for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
// Tokenize the context to count tokens
llm : add Falcon support (#2717) * llama : refactor GGUF constants into static maps * llama : check if model architecture is known * llama : refactor llama_model_load_internal() * gguf : add KV constant maps * llm : read arch-specific KVs * convert : add dummy scores + types * falcon : load tensor data (CPU only) * llama : fix loading progress bar * llama : add arch member to llama_model * falcon : CPU inference working * falcon : support non-40B models * falcon : minor * llama : minor updates ggml-ci * convert-falcon-hf-to-gguf.py : fix special token mapping * llama.cpp : llama default UNK token = id 0 * llama.cpp : fix bpe tokenizer * llama.cpp : fix the fix of bpe tokenizer * ggml : pass eps to ggml_norm * metal : implement RoPE (mode = 2) + avoid ggml_repeat * ggml : ggml_repeat always creates new tensor * falcon : copy-paste self-attention from LLaMA * metal : print extra compute pipeline info * falcon : minor changes (still chasing the Metal problem) * llama.cpp : fix linefeed token * metal : fix GELU kernel numerical stability by using precise::tanh * metal : temporary workaround for the concurrency optimization bug * falcon : add CUDA offloading (#2739) * llama : better model naming and size reporting * llama : prep new tokenizer support * llama : advanced BPE tokenizer based on ggllm.cpp imlpementation * llama : remove oboslete comment ggml-ci * common : remove obsolete BPE API + disable test-tokenizer-1 * llama : revert BPE special-case in llama_byte_to_token() * cuda : add TODOs for RoPE NeoX implementation * llama : default special tokens based on vocab type * perplexity : add log for start of tokenization --------- Co-authored-by: klosax <131523366+klosax@users.noreply.github.com> Co-authored-by: slaren <slarengh@gmail.com>
2023-08-23 22:08:04 +02:00
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
size_t context_size = context_embd.size();
for (int i = 0; i < 4; ++i) {
ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos);
for (int k = 0; k < int(context_size); ++k) {
if (ending_tokens[i][k] != context_embd[k]) {
fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k);
break;
}
}
}
// Do the 1st ending
// In this case we include the context when evaluating
//auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
auto query_embd = ending_tokens[0];
auto query_size = query_embd.size();
// Stop if query wont fit the ctx window
if (query_size > (size_t)params.n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
// Speedup small evaluations by evaluating atleast 32 tokens
if (query_size < 32) {
query_embd.resize(32);
}
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
const auto first_probs = softmax(tok_logits);
hs_data[task_idx].ending_logprob_count[0] = 1;
hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
// Calculate the logprobs over the ending
for (size_t j = context_size; j < query_size - 1; j++) {
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[query_embd[j + 1]];
hs_data[task_idx].ending_logprob[0] += std::log(prob);
hs_data[task_idx].ending_logprob_count[0]++;
}
// Calculate the mean token logprob for acc_norm
hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
// Do the remaining endings
// For these, we use the bare ending with n_past = context_size
//
for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
// Tokenize the query
query_embd.resize(ending_tokens[ending_idx].size() - context_size);
std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int));
query_size = query_embd.size();
// Stop if query wont fit the ctx window
if (context_size + query_size > (size_t)params.n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
// Speedup small evaluations by evaluating atleast 32 tokens
// No, resizing to 32 is actually slightly slower (at least on CUDA)
//if (query_size < 32) {
// query_embd.resize(32);
//}
// Evaluate the query
logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
// Calculate the logprobs over the ending
for (size_t j = 0; j < query_size - 1; j++) {
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[query_embd[j + 1]];
hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
hs_data[task_idx].ending_logprob_count[ending_idx]++;
}
// Calculate the mean token logprob for acc_norm
hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
// printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
// task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
}
// Find the ending with maximum logprob
size_t ending_logprob_max_idx = 0;
double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
for (size_t j = 1; j < 4; j++) {
if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
ending_logprob_max_idx = j;
ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
}
}
// printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
// If the gold ending got the maximum logprobe add one accuracy point
if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
acc += 1.0;
}
// Print the accumulated accuracy mean x 100
printf("%zu\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0);
fflush(stdout);
}
delete [] hs_data;
printf("\n");
}
int main(int argc, char ** argv) {
gpt_params params;
params.n_batch = 512;
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
params.perplexity = true;
params.n_batch = std::min(params.n_batch, params.n_ctx);
if (params.ppl_stride > 0) {
fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
params.n_ctx, params.n_ctx + params.ppl_stride/2);
params.n_ctx += params.ppl_stride/2;
}
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
std::mt19937 rng(params.seed);
if (params.random_prompt) {
params.prompt = gpt_random_prompt(rng);
}
llama_backend_init(params.numa);
llama_model * model;
llama_context * ctx;
// load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
2023-04-17 17:28:55 +02:00
}
const int n_ctx_train = llama_n_ctx_train(ctx);
if (params.n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
}
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
struct results_perplexity results;
if (params.hellaswag) {
hellaswag_score(ctx, params);
} else {
results = perplexity(ctx, params);
}
llama_print_timings(ctx);
write_logfile(ctx, params, model, results);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
return 0;
}