From 18705a30ef3d6a89e1d7c6cb8cfe8633f760cb53 Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Fri, 1 Sep 2023 05:03:49 -0400 Subject: [PATCH] llama2c : fix segfault and alloc-dealloc-mismatch (#2913) * llama2c : fix segfault if vocab is not found * llama2c : fix mismatch between new[] and delete * llama2c : fix basename on Windows * llama2c : use a destructor to prevent memory leaks --- .../convert-llama2c-to-ggml.cpp | 43 ++++++++++--------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index e9e070b1f..0b03c9d2b 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -75,7 +75,7 @@ typedef struct { int seq_len; // max sequence length } Config; -typedef struct { +struct TransformerWeights { // token embedding table float* token_embedding_table; // (vocab_size, dim) // weights for rmsnorms @@ -97,7 +97,22 @@ typedef struct { // float* freq_cis_imag; // (seq_len, dim/2) // (optional) classifier weights for the logits, on the last layer float* wcls; -} TransformerWeights; + + ~TransformerWeights() { + delete[] token_embedding_table; + delete[] rms_att_weight; + delete[] rms_ffn_weight; + delete[] wq; + delete[] wk; + delete[] wv; + delete[] wo; + delete[] w1; + delete[] w2; + delete[] w3; + delete[] rms_final_weight; + delete[] wcls; + } +}; void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) { // we calloc instead of malloc to keep valgrind happy @@ -173,21 +188,6 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shar return 0; } -void free_weights(TransformerWeights* w) { - delete w->token_embedding_table; - delete w->rms_att_weight; - delete w->rms_ffn_weight; - delete w->wq; - delete w->wk; - delete w->wv; - delete w->wo; - delete w->w1; - delete w->w2; - delete w->w3; - delete w->rms_final_weight; - if (w->wcls) delete w->wcls; -} - void print_sample_weights(TransformerWeights *w){ printf("----- Quick print of first of the weight vales of all the variables\n"); printf("%f\n", w->token_embedding_table[0]); @@ -596,6 +596,10 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) // assume llama2.c vocabulary printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename); llama_file file(filename, "rb"); + if (!file.fp) { + fprintf(stderr, "error: %s: %s\n", strerror(errno), filename); + exit(1); + } const int n_vocab = config->vocab_size; /* uint32_t max_token_length = */ file.read_u32(); // unused vocab->id_to_token.resize(n_vocab); @@ -898,7 +902,7 @@ bool params_parse(int argc, char ** argv, struct train_params * params) { } std::string basename(const std::string &path) { - size_t pos = path.find_last_of("/"); + size_t pos = path.find_last_of("/\\"); if (pos == std::string::npos) { return path; } @@ -911,7 +915,7 @@ int main(int argc, char ** argv) { return 1; } Config config; - TransformerWeights weights; + TransformerWeights weights = {}; { FILE *file = fopen(params.fn_llama2c_model, "rb"); if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; } @@ -953,6 +957,5 @@ int main(int argc, char ** argv) { printf("Saving llama.c model file %s in ggml format at %s\n", params.fn_llama2c_model, params.fn_llama2c_output_model); ggml_free(model.ctx); - free_weights(&weights); return 0; }