From 16bc66d9479edd5ee12ec734973554d4493c5dfa Mon Sep 17 00:00:00 2001 From: slaren Date: Thu, 28 Sep 2023 21:42:38 +0200 Subject: [PATCH] llama.cpp : split llama_context_params into model and context params (#3301) * llama.cpp : split llama_context_params into model and context params ggml-ci * fix metal build * fix freq_base/scale default to model value * llama-bench : keep the same model between tests when possible * move n_threads to llama_context_params, add n_threads_batch * fix mpi build * remove kv_size(), cuda scratch fixes * remove low-vram option * add n_threads_batch to system info, refactor to get_system_info() * add documentation about --threads-batch to the READMEs * llama-bench fix * main : fix rope freq/scale warning * llama.cpp : add llama_get_model common : add llama_tokenize from model * remove duplicated ctx/model functions ggml-ci * cuda : print total VRAM used --- common/common.cpp | 110 ++-- common/common.h | 12 +- common/train.cpp | 10 +- examples/batched/batched.cpp | 39 +- examples/beam-search/beam-search.cpp | 4 +- examples/embd-input/embd-input-lib.cpp | 13 +- examples/embd-input/embd-input-test.cpp | 2 +- examples/embedding/embedding.cpp | 21 +- examples/finetune/finetune.cpp | 12 +- examples/llama-bench/llama-bench.cpp | 159 +++-- examples/main/README.md | 4 +- examples/main/main.cpp | 41 +- examples/parallel/parallel.cpp | 6 +- examples/perplexity/perplexity.cpp | 73 +-- examples/quantize-stats/quantize-stats.cpp | 17 +- examples/save-load-state/save-load-state.cpp | 26 +- examples/server/README.md | 4 +- examples/server/server.cpp | 50 +- examples/simple/simple.cpp | 24 +- examples/speculative/speculative.cpp | 16 +- .../train-text-from-scratch.cpp | 12 +- ggml-cuda.cu | 24 +- llama.cpp | 545 ++++++++---------- llama.h | 84 ++- tests/test-tokenizer-0-falcon.cpp | 12 +- tests/test-tokenizer-0-llama.cpp | 12 +- tests/test-tokenizer-1-llama.cpp | 14 +- 27 files changed, 713 insertions(+), 633 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 8764a7be3..6e8c08cb8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -129,6 +129,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { if (params.n_threads <= 0) { params.n_threads = std::thread::hardware_concurrency(); } + } else if (arg == "-tb" || arg == "--threads-batch") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_threads_batch = std::stoi(argv[i]); + if (params.n_threads_batch <= 0) { + params.n_threads_batch = std::thread::hardware_concurrency(); + } } else if (arg == "-p" || arg == "--prompt") { if (++i >= argc) { invalid_param = true; @@ -451,12 +460,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.mul_mat_q = false; #else fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n"); -#endif // GGML_USE_CUBLAS - } else if (arg == "--low-vram" || arg == "-lv") { -#ifdef GGML_USE_CUBLAS - params.low_vram = true; -#else - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n"); #endif // GGML_USE_CUBLAS } else if (arg == "--no-mmap") { params.use_mmap = false; @@ -630,7 +633,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" (can be specified more than once for multiple prompts).\n"); printf(" --color colorise output to distinguish prompt and user input from generations\n"); printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); - printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); + printf(" -tb N, --threads-batch N\n"); + printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); printf(" -p PROMPT, --prompt PROMPT\n"); printf(" prompt to start generation with (default: empty)\n"); printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); @@ -645,7 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -f FNAME, --file FNAME\n"); printf(" prompt file to start generation.\n"); printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); - printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); + printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); @@ -705,7 +710,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -ts SPLIT --tensor-split SPLIT\n"); printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); - printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n"); #ifdef GGML_USE_CUBLAS printf(" -nommq, --no-mul-mat-q\n"); printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n"); @@ -726,6 +730,18 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf("\n"); } +std::string get_system_info(const gpt_params & params) { + std::ostringstream os; + + os << "system_info: n_threads = " << params.n_threads; + if (params.n_threads_batch != -1) { + os << " (n_threads_batch = " << params.n_threads_batch << ")"; + } + os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); + + return os.str(); +} + std::string gpt_random_prompt(std::mt19937 & rng) { const int r = rng() % 10; switch (r) { @@ -749,40 +765,50 @@ std::string gpt_random_prompt(std::mt19937 & rng) { // Model utils // -struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { - auto lparams = llama_context_default_params(); +struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) { + auto mparams = llama_model_default_params(); - lparams.n_ctx = params.n_ctx; - lparams.n_batch = params.n_batch; if (params.n_gpu_layers != -1) { - lparams.n_gpu_layers = params.n_gpu_layers; + mparams.n_gpu_layers = params.n_gpu_layers; } - lparams.main_gpu = params.main_gpu; - lparams.tensor_split = params.tensor_split; - lparams.low_vram = params.low_vram; - lparams.mul_mat_q = params.mul_mat_q; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - lparams.logits_all = params.logits_all; - lparams.embedding = params.embedding; - lparams.rope_freq_base = params.rope_freq_base; - lparams.rope_freq_scale = params.rope_freq_scale; + mparams.main_gpu = params.main_gpu; + mparams.tensor_split = params.tensor_split; + mparams.use_mmap = params.use_mmap; + mparams.use_mlock = params.use_mlock; - return lparams; + return mparams; +} + +struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { + auto cparams = llama_context_default_params(); + + cparams.n_ctx = params.n_ctx; + cparams.n_batch = params.n_batch; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + cparams.mul_mat_q = params.mul_mat_q; + cparams.seed = params.seed; + cparams.f16_kv = params.memory_f16; + cparams.logits_all = params.logits_all; + cparams.embedding = params.embedding; + cparams.rope_freq_base = params.rope_freq_base; + cparams.rope_freq_scale = params.rope_freq_scale; + + return cparams; } std::tuple llama_init_from_gpt_params(gpt_params & params) { - auto lparams = llama_context_params_from_gpt_params(params); + auto mparams = llama_model_params_from_gpt_params(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); + llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); return std::make_tuple(nullptr, nullptr); } - llama_context * lctx = llama_new_context_with_model(model, lparams); + auto cparams = llama_context_params_from_gpt_params(params); + + llama_context * lctx = llama_new_context_with_model(model, cparams); if (lctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); llama_free_model(model); @@ -815,7 +841,7 @@ std::tuple llama_init_from_gpt_par LOG("warming up the model with an empty run\n"); std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; - llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads); + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_kv_cache_tokens_rm(lctx, -1, -1); llama_reset_timings(lctx); } @@ -828,16 +854,23 @@ std::tuple llama_init_from_gpt_par // std::vector llama_tokenize( - struct llama_context * ctx, + const struct llama_context * ctx, + const std::string & text, + bool add_bos) { + return llama_tokenize(llama_get_model(ctx), text, add_bos); +} + +std::vector llama_tokenize( + const struct llama_model * model, const std::string & text, bool add_bos) { // upper limit for the number of tokens int n_tokens = text.length() + add_bos; std::vector result(n_tokens); - n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos); + n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos); + int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -847,10 +880,10 @@ std::vector llama_tokenize( std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) { std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size()); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(ctx, token, result.data(), result.size()); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -905,7 +938,7 @@ llama_token llama_sample_token( std::vector & candidates, int idx) { const int n_ctx = llama_n_ctx(ctx); - const int n_vocab = llama_n_vocab(ctx); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; @@ -1191,7 +1224,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l #endif // NDEBUG fprintf(stream, "model_desc: %s\n", model_desc); - fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(lctx)); + fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); #ifdef __OPTIMIZE__ fprintf(stream, "optimize: true\n"); @@ -1258,7 +1291,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la)); } fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); - fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false"); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false"); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat); diff --git a/common/common.h b/common/common.h index 64601f997..0e2d3fa6c 100644 --- a/common/common.h +++ b/common/common.h @@ -36,6 +36,7 @@ int32_t get_num_physical_cores(); struct gpt_params { uint32_t seed = -1; // RNG seed int32_t n_threads = get_num_physical_cores(); + int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) @@ -95,7 +96,6 @@ struct gpt_params { bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score - bool low_vram = false; // if true, reduce VRAM usage at the cost of performance bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided @@ -126,6 +126,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params); void gpt_print_usage(int argc, char ** argv, const gpt_params & params); +std::string get_system_info(const gpt_params & params); + std::string gpt_random_prompt(std::mt19937 & rng); void process_escapes(std::string& input); @@ -135,6 +137,7 @@ void process_escapes(std::string& input); // std::tuple llama_init_from_gpt_params(gpt_params & params); +struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); // @@ -144,7 +147,12 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param // tokenizes a string into a vector of tokens // should work similar to Python's `tokenizer.encode` std::vector llama_tokenize( - struct llama_context * ctx, + const struct llama_context * ctx, + const std::string & text, + bool add_bos); + +std::vector llama_tokenize( + const struct llama_model * model, const std::string & text, bool add_bos); diff --git a/common/train.cpp b/common/train.cpp index 4a1280966..35a4cf9e6 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -858,7 +858,7 @@ size_t tokenize_file( out_tokens.resize(buf.size() + n_max_tokens_overhead); int n_tokens = llama_tokenize( - lctx, + llama_get_model(lctx), buf.data(), (int) buf.size(), out_tokens.data(), @@ -867,7 +867,7 @@ size_t tokenize_file( if (n_tokens < 0) { out_tokens.resize(-n_tokens); n_tokens = llama_tokenize( - lctx, + llama_get_model(lctx), buf.data(), (int) buf.size(), out_tokens.data(), @@ -920,7 +920,7 @@ size_t tokenize_file( size_t found_max_sample_size = 0; size_t max_token_text_size = 0; - int n_vocab = llama_n_vocab(lctx); + int n_vocab = llama_n_vocab(llama_get_model(lctx)); for (llama_token token=0; token < n_vocab; ++token) { max_token_text_size = std::max( max_token_text_size, @@ -961,7 +961,7 @@ size_t tokenize_file( // tokenize the sample tok_sample.resize(buf_sample.size() + n_max_tokens_overhead); - int n_tokens = llama_tokenize(lctx, + int n_tokens = llama_tokenize(llama_get_model(lctx), buf_sample.data(), (int) buf_sample.size(), tok_sample.data(), @@ -969,7 +969,7 @@ size_t tokenize_file( false); if (n_tokens < 0) { tok_sample.resize(-n_tokens); - n_tokens = llama_tokenize(lctx, + n_tokens = llama_tokenize(llama_get_model(lctx), buf_sample.data(), (int) buf_sample.size(), tok_sample.data(), diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 4dd1d553d..688ef2213 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -40,20 +40,35 @@ int main(int argc, char ** argv) { llama_backend_init(params.numa); - llama_context_params ctx_params = llama_context_default_params(); + // initialize the model - ctx_params.seed = 1234; - ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301) - ctx_params.n_batch = std::max(n_len, n_parallel); - // ctx_params.n_gpu_layers = 99; // offload all layers to the GPU + llama_model_params model_params = llama_model_default_params(); - llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); + // model_params.n_gpu_layers = 99; // offload all layers to the GPU + + llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); return 1; } + // tokenize the prompt + + std::vector tokens_list; + tokens_list = ::llama_tokenize(model, params.prompt, true); + const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel; + + // initialize the context + + llama_context_params ctx_params = llama_context_default_params(); + + ctx_params.seed = 1234; + ctx_params.n_ctx = n_kv_req; + ctx_params.n_batch = std::max(n_len, n_parallel); + ctx_params.n_threads = params.n_threads; + ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + llama_context * ctx = llama_new_context_with_model(model, ctx_params); if (ctx == NULL) { @@ -61,13 +76,7 @@ int main(int argc, char ** argv) { return 1; } - // tokenize the prompt - - std::vector tokens_list; - tokens_list = ::llama_tokenize(ctx, params.prompt, true); - const int n_ctx = llama_n_ctx(ctx); - const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel; LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req); @@ -106,7 +115,7 @@ int main(int argc, char ** argv) { // llama_decode will output logits only for the last token of the prompt batch.logits[batch.n_tokens - 1] = true; - if (llama_decode(ctx, batch, params.n_threads) != 0) { + if (llama_decode(ctx, batch) != 0) { LOG_TEE("%s: llama_decode() failed\n", __func__); return 1; } @@ -146,7 +155,7 @@ int main(int argc, char ** argv) { continue; } - auto n_vocab = llama_n_vocab(ctx); + auto n_vocab = llama_n_vocab(model); auto * logits = llama_get_logits_ith(ctx, i_batch[i]); std::vector candidates; @@ -210,7 +219,7 @@ int main(int argc, char ** argv) { n_cur += 1; // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch, params.n_threads)) { + if (llama_decode(ctx, batch)) { fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); return 1; } diff --git a/examples/beam-search/beam-search.cpp b/examples/beam-search/beam-search.cpp index 63da7c3ec..f078ab8a8 100644 --- a/examples/beam-search/beam-search.cpp +++ b/examples/beam-search/beam-search.cpp @@ -160,7 +160,7 @@ int main(int argc, char ** argv) int n_past = 0; - if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads)) + if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0))) { fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ ); return 1; @@ -170,7 +170,7 @@ int main(int argc, char ** argv) beam_search_callback_data callback_data{ctx, {}}; size_t const beam_width = static_cast(params.n_beams); int const n_predict = 256; - llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads); + llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict); std::cout << "\n\n"; for (llama_token const token_id : callback_data.response) { diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 9bd4d3470..99e6bdad5 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -48,8 +48,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) { // print system information { fprintf(stderr, "\n"); - fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", - params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + fprintf(stderr, "%s\n", get_system_info(params).c_str()); } struct MyModel * ret = new MyModel(); ret->ctx = ctx; @@ -71,7 +70,7 @@ bool eval_float(void * model, float * input, int N){ MyModel * mymodel = (MyModel*)model; llama_context * ctx = mymodel->ctx; gpt_params params = mymodel->params; - int n_emb = llama_n_embd(ctx); + int n_emb = llama_n_embd(llama_get_model(ctx)); int n_past = mymodel->n_past; int n_batch = N; // params.n_batch; @@ -81,7 +80,7 @@ bool eval_float(void * model, float * input, int N){ n_eval = n_batch; } llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, }; - if (llama_decode(ctx, batch, params.n_threads)) { + if (llama_decode(ctx, batch)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } @@ -102,7 +101,7 @@ bool eval_tokens(void * model, std::vector tokens) { if (n_eval > params.n_batch) { n_eval = params.n_batch; } - if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } @@ -133,7 +132,7 @@ llama_token sampling_id(struct MyModel* mymodel) { // out of user input, sample next token const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; + const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k; const float top_p = params.top_p; const float tfs_z = params.tfs_z; const float typical_p = params.typical_p; @@ -149,7 +148,7 @@ llama_token sampling_id(struct MyModel* mymodel) { llama_token id = 0; { auto logits = llama_get_logits(ctx); - auto n_vocab = llama_n_vocab(ctx); + auto n_vocab = llama_n_vocab(llama_get_model(ctx)); // Apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { diff --git a/examples/embd-input/embd-input-test.cpp b/examples/embd-input/embd-input-test.cpp index e5e040f62..dc4a0e488 100644 --- a/examples/embd-input/embd-input-test.cpp +++ b/examples/embd-input/embd-input-test.cpp @@ -8,7 +8,7 @@ int main(int argc, char** argv) { auto mymodel = create_mymodel(argc, argv); int N = 10; int max_tgt_len = 500; - int n_embd = llama_n_embd(mymodel->ctx); + int n_embd = llama_n_embd(llama_get_model(mymodel->ctx)); // add random float embd to test evaluation float * data = new float[N*n_embd]; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 18cefa237..14075609e 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -42,17 +42,18 @@ int main(int argc, char ** argv) { return 1; } - const int n_ctx_train = llama_n_ctx_train(ctx); - if (params.n_ctx > n_ctx_train) { + const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx = llama_n_ctx(ctx); + + if (n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", - __func__, n_ctx_train, params.n_ctx); + __func__, n_ctx_train, n_ctx); } // print system information { fprintf(stderr, "\n"); - fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", - params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + fprintf(stderr, "%s\n", get_system_info(params).c_str()); } int n_past = 0; @@ -70,15 +71,15 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n"); } - if (embd_inp.size() > (size_t)params.n_ctx) { + if (embd_inp.size() > (size_t)n_ctx) { fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n", - __func__, embd_inp.size(), params.n_ctx); + __func__, embd_inp.size(), n_ctx); return 1; } while (!embd_inp.empty()) { int n_tokens = std::min(params.n_batch, (int) embd_inp.size()); - if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } @@ -86,8 +87,8 @@ int main(int argc, char ** argv) { embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens); } - const int n_embd = llama_n_embd(ctx); - const auto embeddings = llama_get_embeddings(ctx); + const int n_embd = llama_n_embd(model); + const auto * embeddings = llama_get_embeddings(ctx); for (int i = 0; i < n_embd; i++) { printf("%f ", embeddings[i]); diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 6e29e1c15..b61165fb7 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -304,7 +304,7 @@ static void init_model(struct llama_model * input, struct my_llama_model * model gguf_free(mctx); } - hparams.n_vocab = llama_model_n_vocab(input); + hparams.n_vocab = llama_n_vocab(input); hparams.n_ctx = n_ctx; // get tensors from llama_model (possibly mmapped) @@ -1540,12 +1540,14 @@ int main(int argc, char ** argv) { printf("%s: seed: %u\n", __func__, params.common.seed); srand(params.common.seed); - struct llama_context_params llama_params = llama_context_default_params(); - llama_params.vocab_only = false; + struct llama_model_params llama_mparams = llama_model_default_params(); + llama_mparams.vocab_only = false; printf("%s: model base = '%s'\n", __func__, params.fn_model_base); - struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_params); - struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params); + struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_mparams); + + struct llama_context_params llama_cparams = llama_context_default_params(); + struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_cparams); struct my_llama_model model; init_model(lmodel, &model, params.fn_model_base, params.common.n_ctx); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 058e34d5c..93bb0c8b1 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -132,7 +132,6 @@ struct cmd_params { std::vector n_gpu_layers; std::vector main_gpu; std::vector mul_mat_q; - std::vector low_vram; std::vector> tensor_split; int reps; bool verbose; @@ -149,7 +148,6 @@ static const cmd_params cmd_params_defaults = { /* n_gpu_layers */ {99}, /* main_gpu */ {0}, /* mul_mat_q */ {true}, - /* low_vram */ {false}, /* tensor_split */ {{}}, /* reps */ 5, /* verbose */ false, @@ -167,9 +165,8 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -b, --batch-size (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str()); printf(" --memory-f32 <0|1> (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str()); printf(" -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); - printf(" -ngl N, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); - printf(" -mg i, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); - printf(" -lv, --low-vram <0|1> (default: %s)\n", join(cmd_params_defaults.low_vram, ",").c_str()); + printf(" -ngl, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); + printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); printf(" -mmq, --mul-mat-q <0|1> (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str()); printf(" -ts, --tensor_split \n"); printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); @@ -255,13 +252,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.main_gpu = split(argv[i], split_delim); - } else if (arg == "-lv" || arg == "--low-vram") { - if (++i >= argc) { - invalid_param = true; - break; - } - auto p = split(argv[i], split_delim); - params.low_vram.insert(params.low_vram.end(), p.begin(), p.end()); } else if (arg == "-mmq" || arg == "--mul-mat-q") { if (++i >= argc) { invalid_param = true; @@ -336,7 +326,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; } if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.mul_mat_q.empty()) { params.mul_mat_q = cmd_params_defaults.mul_mat_q; } - if (params.low_vram.empty()) { params.low_vram = cmd_params_defaults.low_vram; } if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } @@ -353,21 +342,34 @@ struct cmd_params_instance { int n_gpu_layers; int main_gpu; bool mul_mat_q; - bool low_vram; std::array tensor_split; - llama_context_params to_llama_params() const { - llama_context_params lparams = llama_context_default_params(); - lparams.n_ctx = n_prompt + n_gen; - lparams.n_batch = n_batch; - lparams.f16_kv = !f32_kv; - lparams.n_gpu_layers = n_gpu_layers; - lparams.main_gpu = main_gpu; - lparams.mul_mat_q = mul_mat_q; - lparams.low_vram = low_vram; - lparams.tensor_split = tensor_split.data(); + llama_model_params to_llama_mparams() const { + llama_model_params mparams = llama_model_default_params(); - return lparams; + mparams.n_gpu_layers = n_gpu_layers; + mparams.main_gpu = main_gpu; + mparams.tensor_split = tensor_split.data(); + + return mparams; + } + + bool equal_mparams(const cmd_params_instance & other) const { + return model == other.model && + n_gpu_layers == other.n_gpu_layers && + main_gpu == other.main_gpu && + tensor_split == other.tensor_split; + } + + llama_context_params to_llama_cparams() const { + llama_context_params cparams = llama_context_default_params(); + + cparams.n_ctx = n_prompt + n_gen; + cparams.n_batch = n_batch; + cparams.f16_kv = !f32_kv; + cparams.mul_mat_q = mul_mat_q; + + return cparams; } }; @@ -375,13 +377,12 @@ static std::vector get_cmd_params_instances_int(const cmd_p std::vector instances; for (const auto & m : params.model) - for (const auto & nb : params.n_batch) - for (const auto & fk : params.f32_kv) for (const auto & nl : params.n_gpu_layers) for (const auto & mg : params.main_gpu) - for (const auto & mmq : params.mul_mat_q) - for (const auto & lv : params.low_vram) for (const auto & ts : params.tensor_split) + for (const auto & nb : params.n_batch) + for (const auto & fk : params.f32_kv) + for (const auto & mmq : params.mul_mat_q) for (const auto & nt : params.n_threads) { cmd_params_instance instance = { /* .model = */ m, @@ -393,7 +394,6 @@ static std::vector get_cmd_params_instances_int(const cmd_p /* .n_gpu_layers = */ nl, /* .main_gpu = */ mg, /* .mul_mat_q = */ mmq, - /* .low_vram = */ lv, /* .tensor_split = */ ts, }; instances.push_back(instance); @@ -404,6 +404,56 @@ static std::vector get_cmd_params_instances_int(const cmd_p static std::vector get_cmd_params_instances(const cmd_params & params) { std::vector instances; +#if 1 + // this ordering minimizes the number of times that each model needs to be reloaded + for (const auto & m : params.model) + for (const auto & nl : params.n_gpu_layers) + for (const auto & mg : params.main_gpu) + for (const auto & ts : params.tensor_split) + for (const auto & nb : params.n_batch) + for (const auto & fk : params.f32_kv) + for (const auto & mmq : params.mul_mat_q) + for (const auto & nt : params.n_threads) { + for (const auto & n_prompt : params.n_prompt) { + if (n_prompt == 0) { + continue; + } + cmd_params_instance instance = { + /* .model = */ m, + /* .n_prompt = */ n_prompt, + /* .n_gen = */ 0, + /* .n_batch = */ nb, + /* .f32_kv = */ fk, + /* .n_threads = */ nt, + /* .n_gpu_layers = */ nl, + /* .main_gpu = */ mg, + /* .mul_mat_q = */ mmq, + /* .tensor_split = */ ts, + }; + instances.push_back(instance); + } + + for (const auto & n_gen : params.n_gen) { + if (n_gen == 0) { + continue; + } + cmd_params_instance instance = { + /* .model = */ m, + /* .n_prompt = */ 0, + /* .n_gen = */ n_gen, + /* .n_batch = */ nb, + /* .f32_kv = */ fk, + /* .n_threads = */ nt, + /* .n_gpu_layers = */ nl, + /* .main_gpu = */ mg, + /* .mul_mat_q = */ mmq, + /* .tensor_split = */ ts, + }; + instances.push_back(instance); + } + } +#else + // this ordering separates the prompt and generation tests for (const auto & n_prompt : params.n_prompt) { if (n_prompt == 0) { continue; @@ -419,6 +469,7 @@ static std::vector get_cmd_params_instances(const cmd_param auto instances_gen = get_cmd_params_instances_int(params, n_gen, 0); instances.insert(instances.end(), instances_gen.begin(), instances_gen.end()); } +#endif return instances; } @@ -443,7 +494,6 @@ struct test { int n_gpu_layers; int main_gpu; bool mul_mat_q; - bool low_vram; std::array tensor_split; int n_prompt; int n_gen; @@ -463,7 +513,6 @@ struct test { n_gpu_layers = inst.n_gpu_layers; main_gpu = inst.main_gpu; mul_mat_q = inst.mul_mat_q; - low_vram = inst.low_vram; tensor_split = inst.tensor_split; n_prompt = inst.n_prompt; n_gen = inst.n_gen; @@ -524,7 +573,7 @@ struct test { "cpu_info", "gpu_info", "model_filename", "model_type", "model_size", "model_n_params", "n_batch", "n_threads", "f16_kv", - "n_gpu_layers", "main_gpu", "mul_mat_q", "low_vram", "tensor_split", + "n_gpu_layers", "main_gpu", "mul_mat_q", "tensor_split", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts" @@ -543,7 +592,7 @@ struct test { return INT; } if (field == "cuda" || field == "opencl" || field == "metal" || field == "gpu_blas" || field == "blas" || - field == "f16_kv" || field == "mul_mat_q" || field == "low_vram") { + field == "f16_kv" || field == "mul_mat_q") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -574,7 +623,7 @@ struct test { cpu_info, gpu_info, model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params), std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv), - std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), std::to_string(low_vram), tensor_split_str, + std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), tensor_split_str, std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), std::to_string(stdev_ts()) @@ -766,9 +815,6 @@ struct markdown_printer : public printer { if (params.mul_mat_q.size() > 1 || params.mul_mat_q != cmd_params_defaults.mul_mat_q) { fields.push_back("mul_mat_q"); } - if (params.low_vram.size() > 1 || params.low_vram != cmd_params_defaults.low_vram) { - fields.push_back("low_vram"); - } if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) { fields.push_back("tensor_split"); } @@ -889,17 +935,23 @@ struct sql_printer : public printer { static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { std::vector tokens(n_batch, llama_token_bos(ctx)); int n_processed = 0; + + llama_set_n_threads(ctx, n_threads, n_threads); + while (n_processed < n_prompt) { int n_tokens = std::min(n_prompt - n_processed, n_batch); - llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0)); n_processed += n_tokens; } } static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { llama_token token = llama_token_bos(ctx); + + llama_set_n_threads(ctx, n_threads, n_threads); + for (int i = 0; i < n_gen; i++) { - llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads); + llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0)); } } @@ -958,17 +1010,25 @@ int main(int argc, char ** argv) { std::vector params_instances = get_cmd_params_instances(params); - for (const auto & inst : params_instances) { - // TODO: keep the model between tests when possible - llama_context_params lparams = inst.to_llama_params(); + llama_model * lmodel = nullptr; + const cmd_params_instance * prev_inst = nullptr; - llama_model * lmodel = llama_load_model_from_file(inst.model.c_str(), lparams); - if (lmodel == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str()); - return 1; + for (const auto & inst : params_instances) { + // keep the same model between tests when possible + if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) { + if (lmodel) { + llama_free_model(lmodel); + } + + lmodel = llama_load_model_from_file(inst.model.c_str(), inst.to_llama_mparams()); + if (lmodel == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str()); + return 1; + } + prev_inst = &inst; } - llama_context * ctx = llama_new_context_with_model(lmodel, lparams); + llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams()); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str()); llama_free_model(lmodel); @@ -1006,9 +1066,10 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); - llama_free_model(lmodel); } + llama_free_model(lmodel); + p->print_footer(); llama_backend_free(); diff --git a/examples/main/README.md b/examples/main/README.md index 26e1e28dd..a9561c383 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -262,7 +262,8 @@ These options help improve the performance and memory usage of the LLaMA models. ### Number of Threads -- `-t N, --threads N`: Set the number of threads to use during computation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance. +- `-t N, --threads N`: Set the number of threads to use during generation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance. +- `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. In some systems, it is beneficial to use a higher number of threads during batch processing than during generation. If not specified, the number of threads used for batch processing will be the same as the number of threads used for generation. ### Mlock @@ -305,6 +306,5 @@ These options provide extra functionality and customization when running the LLa - `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance. - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS. - `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS. -- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS. - `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains. - `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation. diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 1ed543cbc..fd506773f 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -140,12 +140,17 @@ int main(int argc, char ** argv) { return 0; } - if (params.rope_freq_base != 10000.0) { - LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base); + if (params.n_ctx != 0 && params.n_ctx < 8) { + LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__); + params.n_ctx = 8; } - if (params.rope_freq_scale != 1.0) { - LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale); + if (params.rope_freq_base != 0.0) { + LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base); + } + + if (params.rope_freq_scale != 0.0) { + LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); } LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); @@ -184,20 +189,19 @@ int main(int argc, char ** argv) { return 1; } - const int n_ctx_train = llama_n_ctx_train(ctx); - if (params.n_ctx > n_ctx_train) { + const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx = llama_n_ctx(ctx); + LOG("n_ctx: %d\n", n_ctx); + + if (n_ctx > n_ctx_train) { LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n", - __func__, n_ctx_train, params.n_ctx); - } else if (params.n_ctx < 8) { - LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__); - params.n_ctx = 8; + __func__, n_ctx_train, n_ctx); } // print system information { LOG_TEE("\n"); - LOG_TEE("system_info: n_threads = %d / %d | %s\n", - params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + LOG_TEE("%s\n", get_system_info(params).c_str()); } std::string path_session = params.path_prompt_cache; @@ -211,7 +215,7 @@ int main(int argc, char ** argv) { if (fp != NULL) { std::fclose(fp); - session_tokens.resize(params.n_ctx); + session_tokens.resize(n_ctx); size_t n_token_count_out = 0; if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str()); @@ -226,7 +230,7 @@ int main(int argc, char ** argv) { } } - const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM; LOG("add_bos: %d\n", add_bos); std::vector embd_inp; @@ -267,9 +271,6 @@ int main(int argc, char ** argv) { LOG("guidance_offset: %s", log_tostr(guidance_offset)); } - const int n_ctx = llama_n_ctx(ctx); - LOG("n_ctx: %d\n", n_ctx); - if ((int) embd_inp.size() > n_ctx - 4) { LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); return 1; @@ -466,7 +467,7 @@ int main(int argc, char ** argv) { std::vector embd; std::vector embd_guidance; - const int n_vocab = llama_n_vocab(ctx); + const int n_vocab = llama_n_vocab(model); std::vector candidates; candidates.reserve(n_vocab); @@ -576,7 +577,7 @@ int main(int argc, char ** argv) { for (int i = 0; i < input_size; i += params.n_batch) { int n_eval = std::min(input_size - i, params.n_batch); - if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) { + if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) { LOG_TEE("%s : failed to eval\n", __func__); return 1; } @@ -593,7 +594,7 @@ int main(int argc, char ** argv) { LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { LOG_TEE("%s : failed to eval\n", __func__); return 1; } diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 790189af9..0434ded23 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -108,7 +108,7 @@ int main(int argc, char ** argv) { fflush(stderr); const int n_ctx = llama_n_ctx(ctx); - const int n_vocab = llama_n_vocab(ctx); + const int n_vocab = llama_n_vocab(model); std::vector clients(n_clients); for (size_t i = 0; i < clients.size(); ++i) { @@ -153,7 +153,7 @@ int main(int argc, char ** argv) { batch.logits[i] = false; } - if (llama_decode(ctx, batch, params.n_threads) != 0) { + if (llama_decode(ctx, batch) != 0) { LOG_TEE("%s: llama_decode() failed\n", __func__); return 1; } @@ -272,7 +272,7 @@ int main(int argc, char ** argv) { 0, 0, 0, // unused }; - const int ret = llama_decode(ctx, batch_view, params.n_threads); + const int ret = llama_decode(ctx, batch_view); if (ret != 0) { if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index de08bd4a1..7d0038bd4 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -150,16 +150,18 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM; const bool add_bos = is_spm; fprintf(stderr, "%s: tokenizing the input ..\n", __func__); std::vector tokens = ::llama_tokenize(ctx, params.prompt, add_bos); - if (int(tokens.size()) < 2*params.n_ctx) { - fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx, - params.n_ctx); + const int n_ctx = llama_n_ctx(ctx); + + if (int(tokens.size()) < 2*n_ctx) { + fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx, + n_ctx); fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size()); return {std::move(tokens), 0., {}, {}}; } @@ -175,20 +177,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & return {tokens, -1, logit_history, prob_history}; } - const int calc_chunk = params.n_ctx; + const int calc_chunk = n_ctx; fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk); if (int(tokens.size()) <= calc_chunk) { fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__, - tokens.size(), params.n_ctx, params.ppl_stride); + tokens.size(), n_ctx, params.ppl_stride); return {tokens, -1, logit_history, prob_history}; } const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride; const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); - const int n_vocab = llama_n_vocab(ctx); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_batch = params.n_batch; int count = 0; @@ -215,7 +217,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const int batch_size = std::min(end - batch_start, n_batch); //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { //fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -250,7 +252,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & } //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start); - for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) { + for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) { // Calculate probability of next token, given the previous ones. const std::vector tok_logits( @@ -287,8 +289,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM; const bool add_bos = is_spm; + const int n_ctx = llama_n_ctx(ctx); auto tim1 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenizing the input ..\n", __func__); @@ -298,9 +301,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par auto tim2 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); - if (int(tokens.size()) < 2*params.n_ctx) { - fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx, - params.n_ctx); + if (int(tokens.size()) < 2*n_ctx) { + fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx, + n_ctx); fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size()); return {std::move(tokens), 0., {}, {}}; } @@ -311,10 +314,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par std::vector prob_history; prob_history.resize(tokens.size()); - const int n_chunk_max = tokens.size() / params.n_ctx; + const int n_chunk_max = tokens.size() / n_ctx; const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); - const int n_vocab = llama_n_vocab(ctx); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_batch = params.n_batch; int count = 0; @@ -326,10 +329,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par std::vector workers(std::thread::hardware_concurrency() - 1); for (int i = 0; i < n_chunk; ++i) { - const int start = i * params.n_ctx; - const int end = start + params.n_ctx; + const int start = i * n_ctx; + const int end = start + n_ctx; - const int num_batches = (params.n_ctx + n_batch - 1) / n_batch; + const int num_batches = (n_ctx + n_batch - 1) / n_batch; std::vector logits; @@ -350,7 +353,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par tokens[batch_start] = llama_token_bos(ctx); } - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -358,7 +361,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // restore the original token in case it was set to BOS tokens[batch_start] = token_org; - const auto batch_logits = llama_get_logits(ctx); + const auto * batch_logits = llama_get_logits(ctx); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); } @@ -387,10 +390,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // Example, we have a context window of 512, we will compute perplexity for each of the // last 256 tokens. Then, we split the input up into context window size chunks to // process the entire prompt. - const int first = params.n_ctx/2; - process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first, + const int first = n_ctx/2; + process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); - count += params.n_ctx - first - 1; + count += n_ctx - first - 1; // perplexity is e^(average negative log-likelihood) if (params.ppl_output_type == 0) { @@ -399,7 +402,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par double av = nll/count; double av2 = nll2/count - av*av; if (av2 > 0) av2 = sqrt(av2/(count-1)); - printf("%8d %.4lf %4lf %4lf\n", i*params.n_ctx, std::exp(nll / count), av, av2); + printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); } fflush(stdout); } @@ -420,7 +423,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } static std::vector hellaswag_evaluate_tokens( - llama_context * ctx, std::vector & tokens, int n_past, int n_batch, int n_vocab, int n_thread + llama_context * ctx, std::vector & tokens, int n_past, int n_batch, int n_vocab ) { std::vector result; result.reserve(tokens.size() * n_vocab); @@ -428,7 +431,7 @@ static std::vector hellaswag_evaluate_tokens( for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { size_t n_tokens = tokens.size() - i_chunk * n_batch; n_tokens = std::min(n_tokens, size_t(n_batch)); - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return {}; } @@ -475,7 +478,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { size_t hs_task_count = prompt_lines.size()/6; fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count); - const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM; fprintf(stderr, "================================= is_spm = %d\n", is_spm); // This is needed as usual for LLaMA models @@ -530,7 +533,8 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { printf("\ntask\tacc_norm\n"); double acc = 0.0f; - const int n_vocab = llama_n_vocab(ctx); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_ctx = llama_n_ctx(ctx); std::vector> ending_tokens(4); @@ -558,7 +562,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { auto query_size = query_embd.size(); // Stop if query wont fit the ctx window - if (query_size > (size_t)params.n_ctx) { + if (query_size > (size_t)n_ctx) { fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size); return; } @@ -571,7 +575,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { // clear the KV cache llama_kv_cache_tokens_rm(ctx, -1, -1); - auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads); + auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab); if (logits.empty()) { fprintf(stderr, "%s : failed to eval\n", __func__); return; @@ -608,7 +612,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { query_size = query_embd.size(); // Stop if query wont fit the ctx window - if (context_size + query_size > (size_t)params.n_ctx) { + if (context_size + query_size > (size_t)n_ctx) { fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size); return; } @@ -620,7 +624,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { //} // Evaluate the query - logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads); + logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab); if (logits.empty()) { fprintf(stderr, "%s : failed to eval\n", __func__); return; @@ -716,7 +720,7 @@ int main(int argc, char ** argv) { return 1; } - const int n_ctx_train = llama_n_ctx_train(ctx); + const int n_ctx_train = llama_n_ctx_train(model); if (params.n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, params.n_ctx); @@ -725,8 +729,7 @@ int main(int argc, char ** argv) { // print system information { fprintf(stderr, "\n"); - fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", - params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + fprintf(stderr, "%s\n", get_system_info(params).c_str()); } struct results_perplexity results; diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 94edb94d9..dd76b1cee 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -309,21 +309,22 @@ int main(int argc, char ** argv) { llama_context * ctx; { - auto lparams = llama_context_default_params(); + auto mparams = llama_model_default_params(); + mparams.use_mlock = false; - lparams.n_ctx = 256; - lparams.seed = 1; - lparams.f16_kv = false; - lparams.use_mlock = false; - - model = llama_load_model_from_file(params.model.c_str(), lparams); + model = llama_load_model_from_file(params.model.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); return 1; } - ctx = llama_new_context_with_model(model, lparams); + auto cparams = llama_context_default_params(); + cparams.n_ctx = 256; + cparams.seed = 1; + cparams.f16_kv = false; + + ctx = llama_new_context_with_model(model, cparams); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 6e4d40b9e..acc6dbdfd 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -23,23 +23,17 @@ int main(int argc, char ** argv) { params.n_predict = 16; } - auto lparams = llama_context_default_params(); - - lparams.n_ctx = params.n_ctx; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - auto n_past = 0; auto last_n_tokens_data = std::vector(params.repeat_last_n, 0); // init - auto * model = llama_load_model_from_file(params.model.c_str(), lparams); + llama_model * model; + llama_context * ctx; + + std::tie(model, ctx) = llama_init_from_gpt_params( params ); if (model == nullptr) { return 1; } - auto * ctx = llama_new_context_with_model(model, lparams); if (ctx == nullptr) { llama_free_model(model); return 1; @@ -54,7 +48,7 @@ int main(int argc, char ** argv) { } // evaluate prompt - llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0)); last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); n_past += n_prompt_tokens; @@ -79,7 +73,7 @@ int main(int argc, char ** argv) { for (auto i = 0; i < params.n_predict; i++) { auto * logits = llama_get_logits(ctx); - auto n_vocab = llama_n_vocab(ctx); + auto n_vocab = llama_n_vocab(model); std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { @@ -91,7 +85,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx); llama_free_model(model); @@ -106,7 +100,7 @@ int main(int argc, char ** argv) { llama_free(ctx); // make new context - auto * ctx2 = llama_new_context_with_model(model, lparams); + auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); // Load state (rng, logits, embedding and kv_cache) from file { @@ -139,7 +133,7 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { auto * logits = llama_get_logits(ctx2); - auto n_vocab = llama_n_vocab(ctx2); + auto n_vocab = llama_n_vocab(model); std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { @@ -151,7 +145,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx2); llama_free_model(model); diff --git a/examples/server/README.md b/examples/server/README.md index 517608046..d409e8408 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -4,14 +4,14 @@ This example demonstrates a simple HTTP API server and a simple web front end to Command line options: -- `--threads N`, `-t N`: Set the number of threads to use during computation. +- `--threads N`, `-t N`: Set the number of threads to use during generation. +- `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. If not specified, the number of threads will be set to the number of threads used for generation. - `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`). - `-m ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses. - `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096. - `-ngl N`, `--n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance. - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS. - `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS. -- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS. - `-b N`, `--batch-size N`: Set the batch size for prompt processing. Default: `512`. - `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. Not recommended. - `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 9b9624832..fe9a4255e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -200,6 +200,7 @@ struct llama_server_context llama_model *model = nullptr; llama_context *ctx = nullptr; gpt_params params; + int n_ctx; grammar_parser::parse_state parsed_grammar; llama_grammar *grammar = nullptr; @@ -239,7 +240,7 @@ struct llama_server_context num_prompt_tokens = 0; num_tokens_predicted = 0; generated_text = ""; - generated_text.reserve(params.n_ctx); + generated_text.reserve(n_ctx); generated_token_probs.clear(); truncated = false; stopped_eos = false; @@ -265,8 +266,8 @@ struct llama_server_context LOG_ERROR("unable to load model", {{"model", params_.model}}); return false; } - - last_n_tokens.resize(params.n_ctx); + n_ctx = llama_n_ctx(ctx); + last_n_tokens.resize(n_ctx); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); return true; } @@ -351,19 +352,19 @@ struct llama_server_context { params.n_keep = (int)num_prompt_tokens; } - params.n_keep = std::min(params.n_ctx - 4, params.n_keep); + params.n_keep = std::min(n_ctx - 4, params.n_keep); // if input prompt is too big, truncate like normal - if (num_prompt_tokens >= (size_t)params.n_ctx) + if (num_prompt_tokens >= (size_t)n_ctx) { - const int n_left = (params.n_ctx - params.n_keep) / 2; + const int n_left = (n_ctx - params.n_keep) / 2; std::vector new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); - std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin()); LOG_VERBOSE("input truncated", { - {"n_ctx", params.n_ctx}, + {"n_ctx", n_ctx}, {"n_keep", params.n_keep}, {"n_left", n_left}, {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, @@ -413,7 +414,7 @@ struct llama_server_context completion_token_output result; result.tok = -1; - if (embd.size() >= (size_t)params.n_ctx) + if (embd.size() >= (size_t)n_ctx) { // Shift context @@ -433,7 +434,7 @@ struct llama_server_context truncated = true; LOG_VERBOSE("input truncated", { - {"n_ctx", params.n_ctx}, + {"n_ctx", n_ctx}, {"n_keep", params.n_keep}, {"n_left", n_left}, }); @@ -447,12 +448,11 @@ struct llama_server_context n_eval = params.n_batch; } - if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads)) + if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0))) { LOG_ERROR("failed to eval", { {"n_eval", n_eval}, {"n_past", n_past}, - {"n_threads", params.n_threads}, {"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, }); has_next_token = false; @@ -470,11 +470,11 @@ struct llama_server_context // out of user input, sample next token const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k; + const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(model) : params.top_k; const float top_p = params.top_p; const float tfs_z = params.tfs_z; const float typical_p = params.typical_p; - const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n; + const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; const float repeat_penalty = params.repeat_penalty; const float alpha_presence = params.presence_penalty; const float alpha_frequency = params.frequency_penalty; @@ -486,7 +486,7 @@ struct llama_server_context { auto *logits = llama_get_logits(ctx); - auto n_vocab = llama_n_vocab(ctx); + auto n_vocab = llama_n_vocab(model); // Apply params.logit_bias map for (const auto &it : params.logit_bias) @@ -505,7 +505,7 @@ struct llama_server_context // Apply penalties float nl_logit = logits[llama_token_nl(ctx)]; - auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx); + auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); llama_sample_repetition_penalty(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_repeat, repeat_penalty); @@ -690,7 +690,7 @@ struct llama_server_context std::vector getEmbedding() { - static const int n_embd = llama_n_embd(ctx); + static const int n_embd = llama_n_embd(model); if (!params.embedding) { LOG_WARNING("embedding disabled", { @@ -734,7 +734,6 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" -ts SPLIT --tensor-split SPLIT\n"); printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); - printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n"); printf(" -nommq, --no-mul-mat-q\n"); printf(" use cuBLAS instead of custom mul_mat_q CUDA kernels.\n"); printf(" Not recommended since this is both slower and uses more VRAM.\n"); @@ -918,14 +917,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } #else LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {}); -#endif // GGML_USE_CUBLAS - } - else if (arg == "--low-vram" || arg == "-lv") - { -#ifdef GGML_USE_CUBLAS - params.low_vram = true; -#else - LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n", {}); #endif // GGML_USE_CUBLAS } else if (arg == "--no-mul-mat-q" || arg == "-nommq") @@ -1031,7 +1022,7 @@ static json format_generation_settings(llama_server_context &llama) eos_bias->second < 0.0f && std::isinf(eos_bias->second); return json{ - {"n_ctx", llama.params.n_ctx}, + {"n_ctx", llama.n_ctx}, {"model", llama.params.model_alias}, {"seed", llama.params.seed}, {"temp", llama.params.temp}, @@ -1191,7 +1182,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla const auto &logit_bias = body.find("logit_bias"); if (logit_bias != body.end() && logit_bias->is_array()) { - const int n_vocab = llama_n_vocab(llama.ctx); + const int n_vocab = llama_n_vocab(llama.model); for (const auto &el : *logit_bias) { if (el.is_array() && el.size() == 2 && el[0].is_number_integer()) @@ -1324,6 +1315,7 @@ int main(int argc, char **argv) {"commit", BUILD_COMMIT}}); LOG_INFO("system info", { {"n_threads", params.n_threads}, + {"n_threads_batch", params.n_threads_batch}, {"total_threads", std::thread::hardware_concurrency()}, {"system_info", llama_print_system_info()}, }); @@ -1387,7 +1379,7 @@ int main(int argc, char **argv) if (llama.params.n_beams) { // Fill llama.generated_token_probs vector with final beam. llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams, - llama.n_past, llama.n_remain, llama.params.n_threads); + llama.n_past, llama.n_remain); // Translate llama.generated_token_probs to llama.generated_text. append_to_generated_text_from_generated_token_probs(llama); } else { diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 1616a4a75..24fb16b78 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -33,18 +33,28 @@ int main(int argc, char ** argv) { llama_backend_init(params.numa); - llama_context_params ctx_params = llama_context_default_params(); + // initialize the model - ctx_params.seed = 1234; - ctx_params.n_ctx = 2048; + llama_model_params model_params = llama_model_default_params(); - llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); + // model_params.n_gpu_layers = 99; // offload all layers to the GPU + + llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); return 1; } + // initialize the context + + llama_context_params ctx_params = llama_context_default_params(); + + ctx_params.seed = 1234; + ctx_params.n_ctx = 2048; + ctx_params.n_threads = params.n_threads; + ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + llama_context * ctx = llama_new_context_with_model(model, ctx_params); if (ctx == NULL) { @@ -97,7 +107,7 @@ int main(int argc, char ** argv) { // llama_decode will output logits only for the last token of the prompt batch.logits[batch.n_tokens - 1] = true; - if (llama_decode(ctx, batch, params.n_threads) != 0) { + if (llama_decode(ctx, batch) != 0) { LOG_TEE("%s: llama_decode() failed\n", __func__); return 1; } @@ -112,7 +122,7 @@ int main(int argc, char ** argv) { while (n_cur <= n_len) { // sample the next token { - auto n_vocab = llama_n_vocab(ctx); + auto n_vocab = llama_n_vocab(model); auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); std::vector candidates; @@ -154,7 +164,7 @@ int main(int argc, char ** argv) { n_cur += 1; // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch, params.n_threads)) { + if (llama_decode(ctx, batch)) { fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); return 1; } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 2445d78dc..c5e5b234f 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -70,16 +70,16 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0), params.n_threads); - llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0), params.n_threads); - llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0), params.n_threads); + llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); + llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); + llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0)); const auto t_enc_end = ggml_time_us(); // the 2 models should have the same vocab const int n_ctx = llama_n_ctx(ctx_tgt); - const int n_vocab = llama_n_vocab(ctx_tgt); - //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft)); + const int n_vocab = llama_n_vocab(model_tgt); + //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft)); // how many tokens to draft each time int n_draft = params.n_draft; @@ -173,7 +173,7 @@ int main(int argc, char ** argv) { } llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx); - llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads); + llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0)); ++n_past_dft; // heuristic for n_draft @@ -258,7 +258,7 @@ int main(int argc, char ** argv) { // evaluate the drafted token on the draft model llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx); - llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads); + llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0)); ++n_past_cur; if (grammar_dft != NULL) { @@ -268,7 +268,7 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx); - llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads); + llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0)); ++n_past_tgt; // the first token is always proposed by the traget model before the speculation loop diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index d5205aff6..a9cf8a381 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -976,14 +976,16 @@ int main(int argc, char ** argv) { printf("%s: seed: %u\n", __func__, params.common.seed); srand(params.common.seed); - struct llama_context_params llama_params = llama_context_default_params(); - llama_params.vocab_only = true; + struct llama_model_params mparams = llama_model_default_params(); + mparams.vocab_only = true; - struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params); - struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params); + struct llama_context_params cparams = llama_context_default_params(); + + struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, mparams); + struct llama_context * lctx = llama_new_context_with_model(lmodel, cparams); struct my_llama_model model; - model.hparams.n_vocab = llama_n_vocab(lctx); + model.hparams.n_vocab = llama_n_vocab(lmodel); model.hparams.n_ctx = params.common.n_ctx; model.hparams.n_embd = params.n_embd; model.hparams.n_head = params.n_head; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 29fb7abd4..86d1fe203 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,3 +1,4 @@ +#include #include #include #include @@ -467,7 +468,7 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; static bool g_mul_mat_q = true; static void * g_scratch_buffer = nullptr; -static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default +static size_t g_scratch_size = 0; // disabled by default static size_t g_scratch_offset = 0; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; @@ -6738,14 +6739,10 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te const int64_t ne1 = dst->ne[1]; // TODO: find the optimal values for these - if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && - src1->type == GGML_TYPE_F32 && - dst->type == GGML_TYPE_F32 && - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { - return true; - } - - return false; + return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); } static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ @@ -6901,6 +6898,8 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream); } else { + fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); GGML_ASSERT(false); } @@ -7198,7 +7197,12 @@ void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) { } void ggml_cuda_set_scratch_size(const size_t scratch_size) { - g_scratch_size = scratch_size; + // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously + // it still won't always work as expected, but it's better than nothing + if (scratch_size > g_scratch_size) { + ggml_cuda_free_scratch(); + } + g_scratch_size = std::max(g_scratch_size, scratch_size); } void ggml_cuda_free_scratch() { diff --git a/llama.cpp b/llama.cpp index 7668cb1a7..685712d17 100644 --- a/llama.cpp +++ b/llama.cpp @@ -887,10 +887,10 @@ static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size()); + const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(ctx, token, result.data(), result.size()); + int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -931,9 +931,9 @@ static const size_t MB = kB*kB; static const size_t GB = kB*kB*kB; struct llama_hparams { + bool vocab_only; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on - uint32_t n_ctx; // context size used during inference uint32_t n_embd; uint32_t n_head; uint32_t n_head_kv; @@ -944,8 +944,8 @@ struct llama_hparams { float f_norm_eps; float f_norm_rms_eps; - float rope_freq_base; - float rope_freq_scale; + float rope_freq_base_train; + float rope_freq_scale_train; bool operator!=(const llama_hparams & other) const { return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT @@ -962,15 +962,18 @@ struct llama_hparams { uint32_t n_embd_gqa() const { return n_embd/n_gqa(); } +}; - size_t kv_size() const { - size_t result = 2ull; - result *= (size_t) n_embd_gqa(); - result *= (size_t) n_ctx; - result *= (size_t) n_layer; - result *= sizeof(ggml_fp16_t); - return result; - } +struct llama_cparams { + uint32_t n_ctx; // context size used during inference + uint32_t n_batch; + uint32_t n_threads; // number of threads to use for generation + uint32_t n_threads_batch; // number of threads to use for batch processing + + float rope_freq_base; + float rope_freq_scale; + + bool mul_mat_q; }; struct llama_layer { @@ -1148,11 +1151,8 @@ struct llama_model { }; struct llama_context { - llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} + llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} ~llama_context() { - if (model_owner) { - delete &model; - } #ifdef GGML_USE_METAL if (ctx_metal) { ggml_metal_free(ctx_metal); @@ -1163,27 +1163,26 @@ struct llama_context { } } + llama_cparams cparams; + + const llama_model & model; + + // key + value cache for the self attention + struct llama_kv_cache kv_self; + std::mt19937 rng; bool has_evaluated_once = false; + int64_t t_start_us; + int64_t t_load_us; int64_t t_sample_us = 0; - int64_t t_eval_us = 0; int64_t t_p_eval_us = 0; + int64_t t_eval_us = 0; int32_t n_sample = 0; // number of tokens sampled - int32_t n_eval = 0; // number of eval calls int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - - const llama_model & model; - - bool model_owner = false; - - int64_t t_load_us; - int64_t t_start_us; - - // key + value cache for the self attention - struct llama_kv_cache kv_self; + int32_t n_eval = 0; // number of eval calls // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; @@ -1218,10 +1217,10 @@ static bool llama_kv_cache_init( const struct llama_hparams & hparams, struct llama_kv_cache & cache, ggml_type wtype, + uint32_t n_ctx, int n_gpu_layers) { const uint32_t n_embd = hparams.n_embd_gqa(); const uint32_t n_layer = hparams.n_layer; - const uint32_t n_ctx = hparams.n_ctx; const int64_t n_mem = n_layer*n_ctx; const int64_t n_elements = n_embd*n_mem; @@ -1255,11 +1254,20 @@ static bool llama_kv_cache_init( (void) n_gpu_layers; #ifdef GGML_USE_CUBLAS + size_t vram_kv_cache = 0; + if (n_gpu_layers > (int)n_layer + 1) { ggml_cuda_assign_buffers_no_scratch(cache.v); + LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__); + vram_kv_cache += ggml_nbytes(cache.v); } if (n_gpu_layers > (int)n_layer + 2) { ggml_cuda_assign_buffers_no_scratch(cache.k); + LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__); + vram_kv_cache += ggml_nbytes(cache.k); + } + if (vram_kv_cache > 0) { + LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0); } #endif // GGML_USE_CUBLAS @@ -1715,7 +1723,7 @@ struct llama_model_loader { lmlock->grow_to(size_lock); } break; -#if defined(GGML_USE_CUBLAS) +#ifdef GGML_USE_CUBLAS case GGML_BACKEND_GPU: case GGML_BACKEND_GPU_SPLIT: // old code: @@ -1748,7 +1756,15 @@ struct llama_model_loader { // load LLaMA models // -static std::string llama_model_ftype_name(enum llama_ftype ftype) { +static std::string llama_model_arch_name(llm_arch arch) { + auto it = LLM_ARCH_NAMES.find(arch); + if (it == LLM_ARCH_NAMES.end()) { + return "unknown"; + } + return it->second; +} + +static std::string llama_model_ftype_name(llama_ftype ftype) { if (ftype & LLAMA_FTYPE_GUESSED) { return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; } @@ -1804,10 +1820,7 @@ static void llm_load_arch(llama_model_loader & ml, llama_model & model) { static void llm_load_hparams( llama_model_loader & ml, - llama_model & model, - int n_ctx, - float rope_freq_base, - float rope_freq_scale) { + llama_model & model) { struct gguf_context * ctx = ml.ctx_gguf; const auto kv = LLM_KV(model.arch); @@ -1818,29 +1831,25 @@ static void llm_load_hparams( GGUF_GET_KEY(ctx, model.name, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_GENERAL_NAME)); // get hparams kv - GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST)); - GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); - GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); - GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); + GGUF_GET_KEY(ctx, hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, kv(LLM_KV_TOKENIZER_LIST)); + GGUF_GET_KEY(ctx, hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH)); + GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT)); + GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT)); // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); // rope_freq_base (optional) - if (rope_freq_base == 0.0f) { - rope_freq_base = 10000.0f; - GGUF_GET_KEY(ctx, rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); - } + hparams.rope_freq_base_train = 10000.0f; + GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); // rope_freq_scale (inverse of the kv) is optional - if (rope_freq_scale == 0.0f) { - float ropescale = 1.0f; - GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); - rope_freq_scale = 1.0f/ropescale; - } + float ropescale = 1.0f; + GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); + hparams.rope_freq_scale_train = 1.0f/ropescale; // sanity check for n_rot (optional) { @@ -1907,10 +1916,6 @@ static void llm_load_hparams( }; model.ftype = ml.ftype; - - hparams.n_ctx = n_ctx; - hparams.rope_freq_base = rope_freq_base; - hparams.rope_freq_scale = rope_freq_scale; } // TODO: This should probably be in llama.h @@ -2034,31 +2039,30 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { const auto & vocab = model.vocab; // hparams - LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); - LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix - LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); - LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); - LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim - LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); - LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); - LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); - LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); + LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver)); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); + LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size()); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); + LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); + LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str()); + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9); if (ml.n_bytes < GB) { - LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: model size = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); } else { - LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: model size = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements); } // general kv @@ -2076,13 +2080,9 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { static void llm_load_tensors( llama_model_loader & ml, llama_model & model, - int n_batch, int n_gpu_layers, int main_gpu, const float * tensor_split, - const bool mul_mat_q, - bool low_vram, - ggml_type memory_type, bool use_mlock, llama_progress_callback progress_callback, void * progress_callback_user_data) { @@ -2121,11 +2121,9 @@ static void llm_load_tensors( } (void) main_gpu; - (void) mul_mat_q; -#if defined(GGML_USE_CUBLAS) +#ifdef GGML_USE_CUBLAS LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__); ggml_cuda_set_main_device(main_gpu); - ggml_cuda_set_mul_mat_q(mul_mat_q); #define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT #elif defined(GGML_USE_CLBLAST) @@ -2160,9 +2158,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2226,9 +2224,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2296,9 +2294,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2373,9 +2371,9 @@ static void llm_load_tensors( // norm is not performance relevant on its own but keeping it in VRAM reduces data copying // on Windows however this is detrimental unless everything is on the GPU #ifndef _WIN32 - backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = LLAMA_BACKEND_OFFLOAD; #else - backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; #endif // _WIN32 backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; @@ -2447,20 +2445,12 @@ static void llm_load_tensors( // print memory requirements { - const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1; - // this is the total memory required to run the inference size_t mem_required = ctx_size + mmapped_size - vram_weights; // weights in VRAM not in memory - // this is the memory required by one llama_state - const size_t mem_required_state = scale*hparams.kv_size(); - - LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, - mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); - - (void) n_batch; + LLAMA_LOG_INFO("%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); @@ -2469,36 +2459,17 @@ static void llm_load_tensors( if (n_gpu_layers > (int) hparams.n_layer) { LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__); } - size_t vram_kv_cache = 0; #ifdef GGML_USE_CUBLAS const int max_backend_supported_layers = hparams.n_layer + 3; - const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3; - if (n_gpu_layers > (int) hparams.n_layer + 1) { - if (low_vram) { - LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); - } else { - LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__); - vram_kv_cache += hparams.kv_size() / 2; - } - } - if (n_gpu_layers > (int) hparams.n_layer + 2) { - if (low_vram) { - LLAMA_LOG_WARN("%s: cannot offload k cache to GPU due to low VRAM option\n", __func__); - } else { - LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__); - vram_kv_cache += hparams.kv_size() / 2; - } - } + const int max_offloadable_layers = hparams.n_layer + 3; #elif defined(GGML_USE_CLBLAST) const int max_backend_supported_layers = hparams.n_layer + 1; const int max_offloadable_layers = hparams.n_layer + 1; #endif // GGML_USE_CUBLAS - LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", - __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); - LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n", - __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + LLAMA_LOG_INFO("%s: VRAM used: %.2f MB\n", __func__, vram_weights / 1024.0 / 1024.0); #else (void) n_gpu_layers; #endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) @@ -2511,7 +2482,7 @@ static void llm_load_tensors( } (void) tensor_split; -#if defined(GGML_USE_CUBLAS) +#ifdef GGML_USE_CUBLAS { ggml_cuda_set_tensor_split(tensor_split); } @@ -2533,29 +2504,24 @@ static void llm_load_tensors( static bool llama_model_load( const std::string & fname, llama_model & model, - int n_ctx, - int n_batch, int n_gpu_layers, int main_gpu, const float * tensor_split, - const bool mul_mat_q, - float rope_freq_base, - float rope_freq_scale, - bool low_vram, - ggml_type memory_type, bool use_mmap, bool use_mlock, bool vocab_only, llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - std::unique_ptr ml(new llama_model_loader(fname, use_mmap)); + llama_model_loader ml(fname, use_mmap); - llm_load_arch (*ml, model); - llm_load_hparams(*ml, model, n_ctx, rope_freq_base, rope_freq_scale); - llm_load_vocab (*ml, model); + model.hparams.vocab_only = vocab_only; - llm_load_print_meta(*ml, model); + llm_load_arch (ml, model); + llm_load_hparams(ml, model); + llm_load_vocab (ml, model); + + llm_load_print_meta(ml, model); if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { throw std::runtime_error("vocab size mismatch"); @@ -2567,8 +2533,8 @@ static bool llama_model_load( } llm_load_tensors( - *ml, model, n_batch, n_gpu_layers, - main_gpu, tensor_split, mul_mat_q, low_vram, memory_type, + ml, model, n_gpu_layers, + main_gpu, tensor_split, use_mlock, progress_callback, progress_callback_user_data); } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading model: %s\n", err.what()); @@ -2583,6 +2549,7 @@ static struct ggml_cgraph * llm_build_llama( const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -2590,7 +2557,7 @@ static struct ggml_cgraph * llm_build_llama( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -2598,8 +2565,8 @@ static struct ggml_cgraph * llm_build_llama( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; const float norm_rms_eps = hparams.f_norm_rms_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -2657,9 +2624,6 @@ static struct ggml_cgraph * llm_build_llama( // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded - // - // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal - // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_v = llama_nop; @@ -2975,6 +2939,7 @@ static struct ggml_cgraph * llm_build_baichaun( const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -2982,7 +2947,7 @@ static struct ggml_cgraph * llm_build_baichaun( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -2990,8 +2955,8 @@ static struct ggml_cgraph * llm_build_baichaun( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; const float norm_rms_eps = hparams.f_norm_rms_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -3047,9 +3012,6 @@ static struct ggml_cgraph * llm_build_baichaun( // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded - // - // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal - // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_v = llama_nop; @@ -3382,6 +3344,7 @@ static struct ggml_cgraph * llm_build_falcon( const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -3389,7 +3352,7 @@ static struct ggml_cgraph * llm_build_falcon( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -3397,8 +3360,8 @@ static struct ggml_cgraph * llm_build_falcon( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; const float norm_eps = hparams.f_norm_eps; const int n_gpu_layers = model.n_gpu_layers; @@ -3457,9 +3420,6 @@ static struct ggml_cgraph * llm_build_falcon( // offload functions set the tensor output backend to GPU // tensors are GPU-accelerated if any input or the output has been offloaded - // - // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal - // in that case ggml_cuda_assign_buffers has no effect offload_func_t offload_func_nr = llama_nop; // nr = non-repeating offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_v = llama_nop; @@ -3753,6 +3713,7 @@ static struct ggml_cgraph * llm_build_starcoder( const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; const auto & kv_self = lctx.kv_self; @@ -3760,7 +3721,7 @@ static struct ggml_cgraph * llm_build_starcoder( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = hparams.n_ctx; + const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_head = hparams.n_embd_head(); @@ -4037,8 +3998,7 @@ static struct ggml_cgraph * llama_build_graph( // static int llama_decode_internal( llama_context & lctx, - llama_batch batch, - int n_threads) { + llama_batch batch) { const uint32_t n_tokens = batch.n_tokens; if (n_tokens == 0) { @@ -4046,6 +4006,15 @@ static int llama_decode_internal( return -1; } + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto n_batch = cparams.n_batch; + + GGML_ASSERT(n_tokens <= n_batch); + + int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch; GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT const int64_t t_start_us = ggml_time_us(); @@ -4058,9 +4027,6 @@ static int llama_decode_internal( GGML_ASSERT(n_threads > 0); - const auto & model = lctx.model; - const auto & hparams = model.hparams; - auto & kv_self = lctx.kv_self; GGML_ASSERT(!!kv_self.ctx); @@ -4103,7 +4069,7 @@ static int llama_decode_internal( // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA? - kv_self.n = std::min((int32_t) hparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self))); + kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self))); //printf("kv_self.n = %d\n", kv_self.n); @@ -4128,6 +4094,8 @@ static int llama_decode_internal( ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); } } + + ggml_cuda_set_mul_mat_q(cparams.mul_mat_q); #endif // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -5416,7 +5384,7 @@ void llama_sample_classifier_free_guidance( GGML_ASSERT(ctx); - auto n_vocab = llama_n_vocab(ctx); + auto n_vocab = llama_n_vocab(llama_get_model(ctx)); GGML_ASSERT(n_vocab == (int)candidates->size); GGML_ASSERT(!candidates->sorted); @@ -5445,7 +5413,7 @@ void llama_sample_classifier_free_guidance( llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { GGML_ASSERT(ctx); - auto N = float(llama_n_vocab(ctx)); + auto N = float(llama_n_vocab(llama_get_model(ctx))); int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); @@ -5632,7 +5600,7 @@ struct llama_logit_info { }; llama_logit_info(llama_context * ctx) : logits(llama_get_logits(ctx)) - , n_vocab(llama_n_vocab(ctx)) + , n_vocab(llama_n_vocab(llama_get_model(ctx))) , max_l(*std::max_element(logits, logits + n_vocab)) , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l})) { } @@ -5670,7 +5638,6 @@ struct llama_beam_search_data { size_t n_beams; int n_past; int n_predict; - int n_threads; std::vector beams; std::vector next_beams; @@ -5680,12 +5647,11 @@ struct llama_beam_search_data { // Used to communicate to/from callback on beams state. std::vector beam_views; - llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict, int n_threads) + llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict) : ctx(ctx) , n_beams(n_beams) , n_past(n_past) , n_predict(n_predict) - , n_threads(n_threads) , beam_views(n_beams) { beams.reserve(n_beams); next_beams.reserve(n_beams); @@ -5722,7 +5688,7 @@ struct llama_beam_search_data { } else { // beam is not at end-of-sentence, so branch with next top_k tokens. if (!beam.tokens.empty()) { - llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0), n_threads); + llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0)); } llama_logit_info logit_info(ctx); std::vector next_tokens = logit_info.top_k(n_beams); @@ -5796,7 +5762,7 @@ struct llama_beam_search_data { callback(callback_data, get_beams_state(false)); // Sets common_prefix_length update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. if (common_prefix_length) { - llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0), n_threads); + llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0)); n_past += common_prefix_length; } // Zero-out next_beam probabilities to place them last in following min-heap. @@ -5837,11 +5803,11 @@ struct llama_beam_search_data { void llama_beam_search(llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, - size_t n_beams, int n_past, int n_predict, int n_threads) { + size_t n_beams, int n_past, int n_predict) { assert(ctx); const int64_t t_start_sample_us = ggml_time_us(); - llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict, n_threads); + llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict); beam_search_data.loop(callback, callback_data); @@ -6061,11 +6027,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s nthread = std::thread::hardware_concurrency(); } - std::unique_ptr ml(new llama_model_loader(fname_inp, /*use_mmap*/ false)); + llama_model_loader ml(fname_inp, /*use_mmap*/ false); llama_model model; - llm_load_arch(*ml, model); - llm_load_hparams(*ml, model, 0, 0, 0); + llm_load_arch(ml, model); + llm_load_hparams(ml, model); if (params->only_copy) { ftype = model.ftype; @@ -6075,7 +6041,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s struct gguf_context * ctx_out = gguf_init_empty(); // copy the KV pairs from the input file - gguf_set_kv (ctx_out, ml->ctx_gguf); + gguf_set_kv (ctx_out, ml.ctx_gguf); gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); gguf_set_val_u32(ctx_out, "general.file_type", ftype); @@ -6083,8 +6049,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s int n_attention_wv = 0; int n_feed_forward_w2 = 0; - for (int i = 0; i < ml->n_tensors; ++i) { - struct ggml_tensor * meta = ml->get_tensor_meta(i); + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * meta = ml.get_tensor_meta(i); const std::string name = ggml_get_name(meta); @@ -6120,8 +6086,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::vector> f32_conv_buf; // populate the original tensors so we get an initial meta data - for (int i = 0; i < ml->n_tensors; ++i) { - struct ggml_tensor * meta = ml->get_tensor_meta(i); + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * meta = ml.get_tensor_meta(i); gguf_add_tensor(ctx_out, meta); } @@ -6134,8 +6100,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s // placeholder for the meta data ::zeros(fout, meta_size); - for (int i = 0; i < ml->n_tensors; ++i) { - struct ggml_tensor * tensor = ml->get_tensor_meta(i); + for (int i = 0; i < ml.n_tensors; ++i) { + struct ggml_tensor * tensor = ml.get_tensor_meta(i); const std::string name = ggml_get_name(tensor); @@ -6143,10 +6109,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s read_data.resize(ggml_nbytes(tensor)); } tensor->data = read_data.data(); - ml->load_data_for(tensor); + ml.load_data_for(tensor); LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", - ++idx, ml->n_tensors, + ++idx, ml.n_tensors, ggml_get_name(tensor), llama_format_tensor_shape(tensor).c_str(), ggml_type_name(tensor->type)); @@ -6296,7 +6262,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } -// TODO: after the GGUF PR, this likely won't work and needs to be updated static int llama_apply_lora_from_file_internal( const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads ) { @@ -6575,33 +6540,40 @@ static int llama_apply_lora_from_file_internal( // // interface implementation // +struct llama_model_params llama_model_default_params() { + struct llama_model_params result = { + /*.n_gpu_layers =*/ 0, + /*.main_gpu =*/ 0, + /*.tensor_split =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.vocab_only =*/ false, + /*.use_mmap =*/ true, + /*.use_mlock =*/ false, + }; + +#ifdef GGML_USE_METAL + result.n_gpu_layers = 1; +#endif + + return result; +} struct llama_context_params llama_context_default_params() { struct llama_context_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, /*.n_batch =*/ 512, - /*.n_gpu_layers =*/ 0, - /*.main_gpu =*/ 0, - /*.tensor_split =*/ nullptr, + /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default + /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, - /*.progress_callback =*/ nullptr, - /*.progress_callback_user_data =*/ nullptr, - /*.low_vram =*/ false, /*.mul_mat_q =*/ true, /*.f16_kv =*/ true, /*.logits_all =*/ false, - /*.vocab_only =*/ false, - /*.use_mmap =*/ true, - /*.use_mlock =*/ false, /*.embedding =*/ false, }; -#ifdef GGML_USE_METAL - result.n_gpu_layers = 1; -#endif - return result; } @@ -6660,13 +6632,11 @@ int64_t llama_time_us(void) { struct llama_model * llama_load_model_from_file( const char * path_model, - struct llama_context_params params) { + struct llama_model_params params) { ggml_time_init(); llama_model * model = new llama_model; - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - unsigned cur_percentage = 0; if (params.progress_callback == NULL) { params.progress_callback_user_data = &cur_percentage; @@ -6683,9 +6653,9 @@ struct llama_model * llama_load_model_from_file( }; } - if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers, - params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale, - params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, + if (!llama_model_load(path_model, *model, params.n_gpu_layers, + params.main_gpu, params.tensor_split, + params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); delete model; @@ -6709,18 +6679,33 @@ struct llama_context * llama_new_context_with_model( llama_context * ctx = new llama_context(*model); + const auto & hparams = model->hparams; + auto & cparams = ctx->cparams; + + cparams.n_batch = params.n_batch; + cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; + cparams.rope_freq_base = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base; + cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.mul_mat_q = params.mul_mat_q; + if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); } + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; // reserve memory for context buffers - if (!params.vocab_only) { - if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, params.n_gpu_layers)) { + if (!hparams.vocab_only) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -6731,11 +6716,9 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - const auto & hparams = ctx->model.hparams; - // resized during inference if (params.logits_all) { - ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); + ctx->logits.reserve(cparams.n_ctx*hparams.n_vocab); } else { ctx->logits.reserve(hparams.n_vocab); } @@ -6753,12 +6736,13 @@ struct llama_context * llama_new_context_with_model( ctx->alloc = ggml_allocr_new_measure(tensor_alignment); // build worst-case graph - const uint32_t n_tokens = std::min((int) hparams.n_ctx, params.n_batch); + int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch); + int n_past = cparams.n_ctx - n_tokens; llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, hparams.n_ctx - n_tokens, 0)); + ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0)); #ifdef GGML_USE_METAL - if (params.n_gpu_layers > 0) { + if (model->n_gpu_layers > 0) { ctx->ctx_metal = ggml_metal_init(1); if (!ctx->ctx_metal) { LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); @@ -6773,7 +6757,7 @@ struct llama_context * llama_new_context_with_model( // measure memory requirements for the graph size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; - LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); + LLAMA_LOG_INFO("%s: compute buffer total size = %.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); // recreate allocator with exact memory requirements ggml_allocr_free(ctx->alloc); @@ -6786,24 +6770,42 @@ struct llama_context * llama_new_context_with_model( } #endif #ifdef GGML_USE_CUBLAS - if (params.low_vram) { - LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__); - ggml_cuda_set_scratch_size(0); // disable scratch - } else { - ggml_cuda_set_scratch_size(alloc_size); - LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0); + ggml_cuda_set_scratch_size(alloc_size); + LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0); + + // calculate total VRAM usage + auto add_tensor = [](const ggml_tensor * t, size_t & size) { + if (t->backend == GGML_BACKEND_GPU || t->backend == GGML_BACKEND_GPU_SPLIT) { + size += ggml_nbytes(t); + } + }; + size_t model_vram_size = 0; + for (const auto & kv : model->tensors_by_name) { + add_tensor(kv.second, model_vram_size); } + + size_t kv_vram_size = 0; + add_tensor(ctx->kv_self.k, kv_vram_size); + add_tensor(ctx->kv_self.v, kv_vram_size); + + size_t ctx_vram_size = alloc_size + kv_vram_size; + size_t total_vram_size = model_vram_size + ctx_vram_size; + + LLAMA_LOG_INFO("%s: total VRAM used: %.2f MB (model: %.2f MB, context: %.2f MB)\n", __func__, + total_vram_size / 1024.0 / 1024.0, + model_vram_size / 1024.0 / 1024.0, + ctx_vram_size / 1024.0 / 1024.0); #endif } #ifdef GGML_USE_METAL - if (params.n_gpu_layers > 0) { + if (model->n_gpu_layers > 0) { // this allocates all Metal resources and memory buffers void * data_ptr = NULL; size_t data_size = 0; - if (params.use_mmap) { + if (ctx->model.mapping) { data_ptr = ctx->model.mapping->addr; data_size = ctx->model.mapping->size; } else { @@ -6822,11 +6824,8 @@ struct llama_context * llama_new_context_with_model( return NULL; \ } - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); - - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.data, ctx->buf_compute.size, 0)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); - + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0)); #undef LLAMA_METAL_CHECK_BUF } @@ -6850,63 +6849,37 @@ struct llama_context * llama_new_context_with_model( return ctx; } -static struct llama_context * llama_init_from_file( - const char * path_model, - struct llama_context_params params) { - struct llama_model * model = llama_load_model_from_file(path_model, params); - if (!model) { - return nullptr; - } - - struct llama_context * ctx = llama_new_context_with_model(model, params); - ctx->model_owner = true; - - return ctx; -} - void llama_free(struct llama_context * ctx) { delete ctx; } -int llama_n_vocab(const struct llama_context * ctx) { - return llama_model_n_vocab(&ctx->model); +const llama_model * llama_get_model(const struct llama_context * ctx) { + return &ctx->model; } int llama_n_ctx(const struct llama_context * ctx) { - return llama_model_n_ctx(&ctx->model); + return ctx->cparams.n_ctx; } -int llama_n_ctx_train(const struct llama_context * ctx) { - return llama_model_n_ctx_train(&ctx->model); +enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { + return model->vocab.type; } -int llama_n_embd(const struct llama_context * ctx) { - return llama_model_n_embd(&ctx->model); -} - -enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) { - return ctx->model.vocab.type; -} - -int llama_model_n_vocab(const struct llama_model * model) { +int llama_n_vocab(const struct llama_model * model) { return model->vocab.id_to_token.size(); } -int llama_model_n_ctx(const struct llama_model * model) { - return model->hparams.n_ctx; -} - -int llama_model_n_ctx_train(const struct llama_model * model) { +int llama_n_ctx_train(const struct llama_model * model) { return model->hparams.n_ctx_train; } -int llama_model_n_embd(const struct llama_model * model) { +int llama_n_embd(const struct llama_model * model) { return model->hparams.n_embd; } int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { return snprintf(buf, buf_size, "%s %s %s", - model->name.c_str(), + llama_model_arch_name(model->arch).c_str(), llama_model_type_name(model->type), llama_model_ftype_name(model->ftype).c_str()); } @@ -7131,9 +7104,11 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat { const auto & kv_self = ctx->kv_self; const auto & hparams = ctx->model.hparams; + const auto & cparams = ctx->cparams; + const int n_layer = hparams.n_layer; const int n_embd = hparams.n_embd_gqa(); - const int n_ctx = hparams.n_ctx; + const int n_ctx = cparams.n_ctx; const size_t kv_size = kv_self.buf.size; const int kv_ntok = kv_self.head; @@ -7239,9 +7214,11 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { { const auto & kv_self = ctx->kv_self; const auto & hparams = ctx->model.hparams; + const auto & cparams = ctx->cparams; + const int n_layer = hparams.n_layer; const int n_embd = hparams.n_embd_gqa(); - const int n_ctx = hparams.n_ctx; + const int n_ctx = cparams.n_ctx; size_t kv_size; int kv_ntok; @@ -7378,11 +7355,10 @@ int llama_eval( struct llama_context * ctx, llama_token * tokens, int32_t n_tokens, - int n_past, - int n_threads) { + int n_past) { llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); - const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads); + const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0)); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } @@ -7394,13 +7370,12 @@ int llama_eval_embd( struct llama_context * ctx, float * embd, int32_t n_tokens, - int n_past, - int n_threads) { + int n_past) { llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, }; - const int ret = llama_decode_internal(*ctx, batch, n_threads); + const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } @@ -7408,6 +7383,11 @@ int llama_eval_embd( return ret; } +void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) { + ctx->cparams.n_threads = n_threads; + ctx->cparams.n_threads_batch = n_threads_batch; +} + struct llama_batch llama_batch_get_one( llama_token * tokens, int32_t n_tokens, @@ -7452,9 +7432,8 @@ void llama_batch_free(struct llama_batch batch) { int llama_decode( struct llama_context * ctx, - struct llama_batch batch, - int n_threads) { - const int ret = llama_decode_internal(*ctx, batch, n_threads); + struct llama_batch batch) { + const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } @@ -7499,16 +7478,6 @@ llama_token llama_token_nl(const struct llama_context * ctx) { } int llama_tokenize( - struct llama_context * ctx, - const char * text, - int text_len, - llama_token * tokens, - int n_max_tokens, - bool add_bos) { - return llama_tokenize_with_model(&ctx->model, text, text_len, tokens, n_max_tokens, add_bos); -} - -int llama_tokenize_with_model( const struct llama_model * model, const char * text, int text_len, @@ -7529,13 +7498,9 @@ int llama_tokenize_with_model( return res.size(); } -int llama_token_to_piece(const struct llama_context * ctx, llama_token token, char * buf, int length) { - return llama_token_to_piece_with_model(&ctx->model, token, buf, length); -} - // does not write null-terminator to buf -int llama_token_to_piece_with_model(const struct llama_model * model, llama_token token, char * buf, int length) { - if (0 <= token && token < llama_model_n_vocab(model)) { +int llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int length) { + if (0 <= token && token < llama_n_vocab(model)) { if (llama_is_normal_token(model->vocab, token)) { std::string result = model->vocab.id_to_token[token].text; if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) { diff --git a/llama.h b/llama.h index 046284d74..96ff1f09c 100644 --- a/llama.h +++ b/llama.h @@ -149,32 +149,37 @@ extern "C" { llama_seq_id all_seq_id; // used if seq_id == NULL } llama_batch; - struct llama_context_params { - uint32_t seed; // RNG seed, -1 for random - int32_t n_ctx; // text context - int32_t n_batch; // prompt processing batch size - int32_t n_gpu_layers; // number of layers to store in VRAM - int32_t main_gpu; // the GPU that is used for scratch and small tensors - + struct llama_model_params { + int32_t n_gpu_layers; // number of layers to store in VRAM + int32_t main_gpu; // the GPU that is used for scratch and small tensors const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) - // ref: https://github.com/ggerganov/llama.cpp/pull/2054 - float rope_freq_base; // RoPE base frequency - float rope_freq_scale; // RoPE frequency scaling factor - // called with a progress value between 0 and 1, pass NULL to disable llama_progress_callback progress_callback; // context pointer passed to the progress callback void * progress_callback_user_data; // Keep the booleans together to avoid misalignment during copy-by-value. - bool low_vram; // if true, reduce VRAM usage at the cost of performance - bool mul_mat_q; // if true, use experimental mul_mat_q kernels - bool f16_kv; // use fp16 for KV cache - bool logits_all; // the llama_eval() call computes all logits, not just the last one bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible bool use_mlock; // force system to keep model in RAM + }; + + struct llama_context_params { + uint32_t seed; // RNG seed, -1 for random + uint32_t n_ctx; // text context + uint32_t n_batch; // prompt processing batch size + uint32_t n_threads; // number of threads to use for generation + uint32_t n_threads_batch; // number of threads to use for batch processing + + // ref: https://github.com/ggerganov/llama.cpp/pull/2054 + float rope_freq_base; // RoPE base frequency + float rope_freq_scale; // RoPE frequency scaling factor + + // Keep the booleans together to avoid misalignment during copy-by-value. + bool mul_mat_q; // if true, use experimental mul_mat_q kernels + bool f16_kv; // use fp16 for KV cache + bool logits_all; // the llama_eval() call computes all logits, not just the last one bool embedding; // embedding mode only }; @@ -236,6 +241,7 @@ extern "C" { }; // Helpers for getting default parameters + LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); @@ -249,7 +255,7 @@ extern "C" { LLAMA_API struct llama_model * llama_load_model_from_file( const char * path_model, - struct llama_context_params params); + struct llama_model_params params); LLAMA_API void llama_free_model(struct llama_model * model); @@ -266,17 +272,15 @@ extern "C" { LLAMA_API bool llama_mmap_supported (void); LLAMA_API bool llama_mlock_supported(void); - LLAMA_API int llama_n_vocab (const struct llama_context * ctx); + LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + LLAMA_API int llama_n_ctx (const struct llama_context * ctx); - LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx); - LLAMA_API int llama_n_embd (const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx); + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); - LLAMA_API int llama_model_n_vocab (const struct llama_model * model); - LLAMA_API int llama_model_n_ctx (const struct llama_model * model); - LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model); - LLAMA_API int llama_model_n_embd (const struct llama_model * model); + LLAMA_API int llama_n_vocab (const struct llama_model * model); + LLAMA_API int llama_n_ctx_train(const struct llama_model * model); + LLAMA_API int llama_n_embd (const struct llama_model * model); // Get a string describing the model type LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); @@ -409,8 +413,7 @@ extern "C" { struct llama_context * ctx, llama_token * tokens, int32_t n_tokens, - int n_past, - int n_threads), + int n_past), "use llama_decode() instead"); // Same as llama_eval, but use float matrix input directly. @@ -419,8 +422,7 @@ extern "C" { struct llama_context * ctx, float * embd, int32_t n_tokens, - int n_past, - int n_threads), + int n_past), "use llama_decode() instead"); // Return batch for single sequence of tokens starting at pos_0 @@ -452,8 +454,12 @@ extern "C" { // < 0 - error LLAMA_API int llama_decode( struct llama_context * ctx, - struct llama_batch batch, - int n_threads); + struct llama_batch batch); + + // Set the number of threads used for decoding + // n_threads is the number of threads used for generation (single token) + // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) + LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); // Token logits obtained from the last call to llama_eval() // The logits for the last token are stored in the last row @@ -494,14 +500,6 @@ extern "C" { // Returns the number of tokens on success, no more than n_max_tokens // Returns a negative number on failure - the number of tokens that would have been returned LLAMA_API int llama_tokenize( - struct llama_context * ctx, - const char * text, - int text_len, - llama_token * tokens, - int n_max_tokens, - bool add_bos); - - LLAMA_API int llama_tokenize_with_model( const struct llama_model * model, const char * text, int text_len, @@ -514,12 +512,6 @@ extern "C" { // Does not write null terminator to the buffer. // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. LLAMA_API int llama_token_to_piece( - const struct llama_context * ctx, - llama_token token, - char * buf, - int length); - - LLAMA_API int llama_token_to_piece_with_model( const struct llama_model * model, llama_token token, char * buf, @@ -700,15 +692,13 @@ extern "C" { /// @param n_beams Number of beams to use. /// @param n_past Number of tokens already evaluated. /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. - /// @param n_threads Number of threads as passed to llama_eval(). LLAMA_API void llama_beam_search( struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, - int n_predict, - int n_threads); + int n_predict); // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); diff --git a/tests/test-tokenizer-0-falcon.cpp b/tests/test-tokenizer-0-falcon.cpp index 836fb8ad2..d51851e20 100644 --- a/tests/test-tokenizer-0-falcon.cpp +++ b/tests/test-tokenizer-0-falcon.cpp @@ -62,18 +62,20 @@ int main(int argc, char **argv) { // load the vocab { - auto lparams = llama_context_default_params(); + auto mparams = llama_model_default_params(); - lparams.vocab_only = true; + mparams.vocab_only = true; - model = llama_load_model_from_file(fname.c_str(), lparams); + model = llama_load_model_from_file(fname.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); return 1; } - ctx = llama_new_context_with_model(model, lparams); + auto cparams = llama_context_default_params(); + + ctx = llama_new_context_with_model(model, cparams); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); @@ -82,7 +84,7 @@ int main(int argc, char **argv) { } } - if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_BPE) { + if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) { fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__); llama_free_model(model); llama_free(ctx); diff --git a/tests/test-tokenizer-0-llama.cpp b/tests/test-tokenizer-0-llama.cpp index dfb2e81a9..91c841f7b 100644 --- a/tests/test-tokenizer-0-llama.cpp +++ b/tests/test-tokenizer-0-llama.cpp @@ -64,18 +64,20 @@ int main(int argc, char **argv) { // load the vocab { - auto lparams = llama_context_default_params(); + auto mparams = llama_model_default_params(); - lparams.vocab_only = true; + mparams.vocab_only = true; - model = llama_load_model_from_file(fname.c_str(), lparams); + model = llama_load_model_from_file(fname.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); return 1; } - ctx = llama_new_context_with_model(model, lparams); + auto cparams = llama_context_default_params(); + + ctx = llama_new_context_with_model(model, cparams); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); @@ -84,7 +86,7 @@ int main(int argc, char **argv) { } } - if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_SPM) { + if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_SPM) { fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__); llama_free_model(model); llama_free(ctx); diff --git a/tests/test-tokenizer-1-llama.cpp b/tests/test-tokenizer-1-llama.cpp index a95d462cf..3b2fc87ac 100644 --- a/tests/test-tokenizer-1-llama.cpp +++ b/tests/test-tokenizer-1-llama.cpp @@ -52,18 +52,20 @@ int main(int argc, char **argv) { // load the vocab { - auto lparams = llama_context_default_params(); + auto mparams = llama_model_default_params(); - lparams.vocab_only = true; + mparams.vocab_only = true; - model = llama_load_model_from_file(fname.c_str(), lparams); + model = llama_load_model_from_file(fname.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); return 1; } - ctx = llama_new_context_with_model(model, lparams); + auto cparams = llama_context_default_params(); + + ctx = llama_new_context_with_model(model, cparams); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); @@ -72,7 +74,7 @@ int main(int argc, char **argv) { } } - GGML_ASSERT(llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM); + GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); #ifdef _WIN32 // We need this for unicode console support @@ -80,7 +82,7 @@ int main(int argc, char **argv) { atexit([]() { console::cleanup(); }); #endif - const int n_vocab = llama_n_vocab(ctx); + const int n_vocab = llama_n_vocab(model); for (int i = 0; i < n_vocab; ++i) { std::string str = llama_detokenize_spm(ctx, std::vector(1, i));