From 7a9b6c3a8bdc1cb75fefc826dfaa7331eb63695d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 24 Mar 2023 23:17:37 +0200 Subject: [PATCH] Reduce memory usage and allocate enough memory for largest context (#473) * Reduce memory usage and allocate enough memory for large contexts * Simpler scratch buffer usage * Reenable BLAS for quantized mul_mat * Fix number of layers in 30B and 65B * Fix KV cache size for F32 --- ggml.c | 12 +- llama.cpp | 326 ++++++++++++++++++++++++++++++++++++++++++++---------- main.cpp | 23 +++- utils.cpp | 10 +- utils.h | 16 +-- 5 files changed, 307 insertions(+), 80 deletions(-) diff --git a/ggml.c b/ggml.c index 92b857a00..cfdf427df 100644 --- a/ggml.c +++ b/ggml.c @@ -5846,7 +5846,8 @@ static bool ggml_compute_forward_mul_mat_use_blas( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - UNUSED(src0); + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; const int ne10 = src1->ne[0]; @@ -5856,7 +5857,14 @@ static bool ggml_compute_forward_mul_mat_use_blas( // TODO: find the optimal values for these if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) { - //printf("BLAS: %d %d %d\n", ne0, ne1, ne10); + + //// disable BLAS for Q4_0 and Q4_1 + //// looks like there is no benefit and we only waste a lot of memory + //if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1) { + // return false; + //} + + //printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01); return true; } diff --git a/llama.cpp b/llama.cpp index 9a93409cc..9d48ccd4c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5,12 +5,25 @@ #include #include #include +#include #include #include #include #include #include +#define LLAMA_USE_SCRATCH +#define LLAMA_MAX_SCRATCH_BUFFERS 16 + +#define LLAMA_ASSERT(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + + // determine number of model parts based on the dimension static const std::unordered_map LLAMA_N_PARTS = { { 4096, 1 }, @@ -19,6 +32,52 @@ static const std::unordered_map LLAMA_N_PARTS = { { 8192, 8 }, }; +// available llama models +enum e_model { + MODEL_UNKNOWN, + MODEL_7B, + MODEL_13B, + MODEL_30B, + MODEL_65B, +}; + +static const size_t MB = 1024*1024; + +// computed for n_ctx == 2048 +// TODO: dynamically determine these sizes +// needs modifications in ggml + +static const std::map MEM_REQ_SCRATCH0 = { + { MODEL_7B, 512ull*MB }, + { MODEL_13B, 512ull*MB }, + { MODEL_30B, 512ull*MB }, + { MODEL_65B, 512ull*MB }, +}; + +static const std::map MEM_REQ_SCRATCH1 = { + { MODEL_7B, 512ull*MB }, + { MODEL_13B, 512ull*MB }, + { MODEL_30B, 512ull*MB }, + { MODEL_65B, 512ull*MB }, +}; + +// 2*n_embd*n_ctx*n_layer*sizeof(float16) +static const std::map MEM_REQ_KV_SELF = { + { MODEL_7B, 1026ull*MB }, + { MODEL_13B, 1608ull*MB }, + { MODEL_30B, 3124ull*MB }, + { MODEL_65B, 5120ull*MB }, +}; + +// this is mostly needed for temporary mul_mat buffers to dequantize the data +// not actually needed if BLAS is disabled +static const std::map MEM_REQ_EVAL = { + { MODEL_7B, 768ull*MB }, + { MODEL_13B, 1024ull*MB }, + { MODEL_30B, 1280ull*MB }, + { MODEL_65B, 1536ull*MB }, +}; + // default hparams (LLaMA 7B) struct llama_hparams { int32_t n_vocab = 32000; @@ -50,7 +109,20 @@ struct llama_layer { struct ggml_tensor * w3; }; +struct llama_kv_cache { + struct ggml_tensor * k; + struct ggml_tensor * v; + + struct ggml_context * ctx; + + std::vector buf; + + int n; // number of tokens currently in the cache +}; + struct llama_model { + e_model type = MODEL_UNKNOWN; + llama_hparams hparams; struct ggml_tensor * tok_embeddings; @@ -60,12 +132,18 @@ struct llama_model { std::vector layers; - // key + value memory - struct ggml_tensor * memory_k; - struct ggml_tensor * memory_v; - - // + // context struct ggml_context * ctx; + + // key + value cache for the self attention + // TODO: move to llama_state + struct llama_kv_cache kv_self; + + // the model memory buffer + std::vector buf; + + // tensors + int n_loaded; std::unordered_map tensors; }; @@ -105,8 +183,88 @@ struct llama_context { // input embedding (1-dimensional array: [n_embd]) std::vector embedding; + + // memory buffers used to evaluate the model + // TODO: move in llama_state + std::vector buf_compute; + std::vector buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS]; + + int buf_last = 0; + size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; + + void use_buf(struct ggml_context * ctx, int i) { +#if defined(LLAMA_USE_SCRATCH) + size_t last_size = 0; + + if (i == -1) { + last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); + } else { + auto & buf = buf_scratch[i]; + last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), }); + } + + if (buf_last >= 0) { + buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size); + } + + buf_last = i; +#else + (void) i; + (void) ctx; +#endif + } + + size_t get_buf_max_mem(int i) const { +#if defined(LLAMA_USE_SCRATCH) + return buf_max_size[i]; +#else + (void) i; + return 0; +#endif + } }; +// +// kv cache +// + +static bool kv_cache_init( + const struct llama_hparams & hparams, + struct llama_kv_cache & cache, + ggml_type wtype, + int n_ctx) { + const int n_embd = hparams.n_embd; + const int n_layer = hparams.n_layer; + + const int n_mem = n_layer*n_ctx; + const int n_elements = n_embd*n_mem; + + cache.buf.resize(2*n_elements*ggml_type_size(wtype) + 2u*MB); + + struct ggml_init_params params; + params.mem_size = cache.buf.size(); + params.mem_buffer = cache.buf.data(); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +static void kv_cache_free(struct llama_kv_cache & cache) { + if (cache.ctx) { + ggml_free(cache.ctx); + cache.ctx = nullptr; + } +} + struct llama_context_params llama_context_default_params() { struct llama_context_params result = { /*.n_ctx =*/ 512, @@ -204,6 +362,22 @@ static bool llama_model_load( fprintf(stderr, "%s: use '--n_parts 1' if necessary\n", __func__); } + if (hparams.n_layer == 32) { + model.type = e_model::MODEL_7B; + } + + if (hparams.n_layer == 40) { + model.type = e_model::MODEL_13B; + } + + if (hparams.n_layer == 60) { + model.type = e_model::MODEL_30B; + } + + if (hparams.n_layer == 80) { + model.type = e_model::MODEL_65B; + } + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx); fprintf(stderr, "%s: n_embd = %d\n", __func__, hparams.n_embd); @@ -214,6 +388,7 @@ static bool llama_model_load( fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); fprintf(stderr, "%s: n_ff = %d\n", __func__, n_ff); fprintf(stderr, "%s: n_parts = %d\n", __func__, n_parts); + fprintf(stderr, "%s: type = %d\n", __func__, model.type); } // load vocab @@ -307,11 +482,32 @@ static bool llama_model_load( fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); } + // print memory requirements + { + const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1; + + // this is the total memory required to run the inference + const size_t mem_required = + ctx_size + + MEM_REQ_SCRATCH0.at(model.type) + + MEM_REQ_SCRATCH1.at(model.type) + + MEM_REQ_EVAL.at (model.type); + + // this is the memory required by one llama_state + const size_t mem_required_state = + scale*MEM_REQ_KV_SELF.at(model.type); + + fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, + mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); + } + // create the ggml context { + lctx.model.buf.resize(ctx_size); + struct ggml_init_params params = { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, + /*.mem_size =*/ lctx.model.buf.size(), + /*.mem_buffer =*/ lctx.model.buf.data(), }; model.ctx = ggml_init(params); @@ -374,25 +570,6 @@ static bool llama_model_load( } } - // key + value memory - { - const auto & hparams = model.hparams; - - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; - - const int n_mem = n_layer*n_ctx; - const int n_elements = n_embd*n_mem; - - model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements); - model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements); - - const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); - - fprintf(stderr, "%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem); - } - const size_t file_offset = fin.tellg(); fin.close(); @@ -416,9 +593,10 @@ static bool llama_model_load( // load weights { - int n_tensors = 0; size_t total_size = 0; + model.n_loaded = 0; + fprintf(stderr, "%s: ", __func__); while (true) { @@ -583,7 +761,10 @@ static bool llama_model_load( } //fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0); - if (++n_tensors % 8 == 0) { + model.n_loaded++; + + // progress + if (model.n_loaded % 8 == 0) { fprintf(stderr, "."); fflush(stderr); } @@ -591,7 +772,13 @@ static bool llama_model_load( fprintf(stderr, " done\n"); - fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors); + fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded); + if (model.n_loaded == 0) { + fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (model.n_loaded != (int) model.tensors.size()) { + fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + return false; + } } fin.close(); @@ -622,6 +809,10 @@ static bool llama_eval_internal( const auto & model = lctx.model; const auto & hparams = model.hparams; + auto & kv_self = model.kv_self; + + LLAMA_ASSERT(!!kv_self.ctx); + const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; const int n_ctx = hparams.n_ctx; @@ -630,27 +821,11 @@ static bool llama_eval_internal( const int n_rot = hparams.n_embd/hparams.n_head; auto & mem_per_token = lctx.mem_per_token; - - // TODO: fix this hardcoded size - static size_t buf_size = 2048u*1024*1024; // TMP !!! - static void * buf = malloc(buf_size); - - if (mem_per_token > 0 && mem_per_token*N > buf_size) { - const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead - //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); - - // reallocate - buf_size = buf_size_new; - buf = realloc(buf, buf_size); - if (buf == nullptr) { - fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); - return false; - } - } + auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf, + /*.mem_size =*/ buf_compute.size(), + /*.mem_buffer =*/ buf_compute.data(), }; struct ggml_context * ctx0 = ggml_init(params); @@ -667,6 +842,8 @@ static bool llama_eval_internal( struct ggml_tensor * cur; + lctx.use_buf(ctx0, 0); + // norm { cur = ggml_rms_norm(ctx0, inpL); @@ -685,8 +862,8 @@ static bool llama_eval_internal( // store key and value to memory if (N >= 1) { - struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); @@ -707,7 +884,7 @@ static bool llama_eval_internal( ggml_permute(ctx0, ggml_rope(ctx0, ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd), + ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), n_embd/n_head, n_head, n_past + N), n_past, n_rot, 1), 0, 2, 1, 3); @@ -733,7 +910,7 @@ static bool llama_eval_internal( ggml_cpy(ctx0, ggml_permute(ctx0, ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd), + ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd), n_embd/n_head, n_head, n_past + N), 1, 2, 0, 3), ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head)); @@ -755,6 +932,8 @@ static bool llama_eval_internal( cur); } + lctx.use_buf(ctx0, 1); + struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); // feed-forward network @@ -773,7 +952,6 @@ static bool llama_eval_internal( model.layers[il].w3, cur); - cur = ggml_mul_mat(ctx0, model.layers[il].w1, cur); @@ -788,17 +966,20 @@ static bool llama_eval_internal( cur); } - cur = ggml_add(ctx0, cur, inpFF); + cur = ggml_add(ctx0, cur, inpFF); // input for next layer inpL = cur; } + lctx.use_buf(ctx0, 0); + // used at the end to optionally extract the embeddings struct ggml_tensor * embeddings = NULL; // norm { + inpL = ggml_rms_norm(ctx0, inpL); // inpL = norm*inpL @@ -810,9 +991,9 @@ static bool llama_eval_internal( } // lm_head - { - inpL = ggml_mul_mat(ctx0, model.output, inpL); - } + inpL = ggml_mul_mat(ctx0, model.output, inpL); + + lctx.use_buf(ctx0, -1); // logits -> probs //inpL = ggml_soft_max(ctx0, inpL); @@ -854,7 +1035,13 @@ static bool llama_eval_internal( if (mem_per_token == 0) { mem_per_token = ggml_used_mem(ctx0)/N; } - //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0)); + +#if 0 + printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__, + ggml_used_mem(ctx0)/1024.0/1024.0, + lctx.get_buf_max_mem(0)/1024.0/1024.0, + lctx.get_buf_max_mem(1)/1024.0/1024.0); +#endif ggml_free(ctx0); @@ -1427,9 +1614,9 @@ struct llama_context * llama_init_from_file( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; - ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; + ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, + if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type, params.vocab_only)) { fprintf(stderr, "%s: failed to load model\n", __func__); llama_free(ctx); @@ -1448,6 +1635,17 @@ struct llama_context * llama_init_from_file( // reserve memory for context buffers { + if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + llama_free(ctx); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v); + fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + const auto & hparams = ctx->model.hparams; if (params.logits_all) { ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); @@ -1458,12 +1656,19 @@ struct llama_context * llama_init_from_file( if (params.embedding){ ctx->embedding.reserve(hparams.n_embd); } + + ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type)); + + ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type)); + ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type)); } return ctx; } void llama_free(struct llama_context * ctx) { + kv_cache_free(ctx->model.kv_self); + if (ctx->model.ctx) { ggml_free(ctx->model.ctx); } @@ -1619,4 +1824,3 @@ const char * llama_print_system_info(void) { return s.c_str(); } - diff --git a/main.cpp b/main.cpp index 44437750e..bc71a5494 100644 --- a/main.cpp +++ b/main.cpp @@ -217,11 +217,23 @@ int main(int argc, char ** argv) { params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); } - // determine the required inference memory per token: - // TODO: better way to do that - { - const std::vector tmp = { 0, 1, 2, 3 }; - llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); + // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters + // uncomment the "used_mem" line in llama.cpp to see the results + if (params.mem_test) { + { + const std::vector tmp(params.n_batch, 0); + llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); + } + + { + const std::vector tmp = { 0, }; + llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads); + } + + llama_print_timings(ctx); + llama_free(ctx); + + return 0; } if (params.perplexity) { @@ -508,7 +520,6 @@ int main(int argc, char ** argv) { #endif llama_print_timings(ctx); - llama_free(ctx); set_console_state(CONSOLE_STATE_DEFAULT); diff --git a/utils.cpp b/utils.cpp index 10673fb82..2f995c12d 100644 --- a/utils.cpp +++ b/utils.cpp @@ -79,8 +79,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.n_ctx = std::stoi(argv[i]); - } else if (arg == "--memory_f16") { - params.memory_f16 = true; + } else if (arg == "--memory_f32") { + params.memory_f16 = false; } else if (arg == "--top_p") { if (++i >= argc) { invalid_param = true; @@ -111,6 +111,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.n_batch = std::stoi(argv[i]); + params.n_batch = std::min(512, params.n_batch); } else if (arg == "-m" || arg == "--model") { if (++i >= argc) { invalid_param = true; @@ -131,6 +132,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.use_color = true; } else if (arg == "--mlock") { params.use_mlock = true; + } else if (arg == "--mtest") { + params.mem_test = true; } else if (arg == "-r" || arg == "--reverse-prompt") { if (++i >= argc) { invalid_param = true; @@ -193,7 +196,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty); fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating\n"); - fprintf(stderr, " --memory_f16 use f16 instead of f32 for memory key+value\n"); + fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n"); fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp); fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n"); fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); @@ -201,6 +204,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { if (ggml_mlock_supported()) { fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n"); } + fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); fprintf(stderr, "\n"); diff --git a/utils.h b/utils.h index cf914990c..d469bc6a0 100644 --- a/utils.h +++ b/utils.h @@ -14,12 +14,13 @@ // struct gpt_params { - int32_t seed = -1; // RNG seed + int32_t seed = -1; // RNG seed int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); - int32_t n_predict = 128; // new tokens to predict - int32_t repeat_last_n = 64; // last n tokens to penalize - int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) - int32_t n_ctx = 512; //context size + int32_t n_predict = 128; // new tokens to predict + int32_t repeat_last_n = 64; // last n tokens to penalize + int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) + int32_t n_ctx = 512; // context size + int32_t n_batch = 8; // batch size for prompt processing // sampling parameters int32_t top_k = 40; @@ -27,15 +28,13 @@ struct gpt_params { float temp = 0.80f; float repeat_penalty = 1.10f; - int32_t n_batch = 8; // batch size for prompt processing - std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string prompt = ""; std::vector antiprompt; // string upon seeing which more user input is prompted - bool memory_f16 = false; // use f16 instead of f32 for memory kv + bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode @@ -47,6 +46,7 @@ struct gpt_params { bool ignore_eos = false; // do not stop generating after eos bool perplexity = false; // compute perplexity over the prompt bool use_mlock = false; // use mlock to keep model in memory + bool mem_test = false; // compute maximum memory usage }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params);