llama.cpp : split llama_context_params into model and context params (#3301)

* llama.cpp : split llama_context_params into model and context params

ggml-ci

* fix metal build

* fix freq_base/scale default to model value

* llama-bench : keep the same model between tests when possible

* move n_threads to llama_context_params, add n_threads_batch

* fix mpi build

* remove kv_size(), cuda scratch fixes

* remove low-vram option

* add n_threads_batch to system info, refactor to get_system_info()

* add documentation about --threads-batch to the READMEs

* llama-bench fix

* main : fix rope freq/scale warning

* llama.cpp : add llama_get_model
common : add llama_tokenize from model

* remove duplicated ctx/model functions

ggml-ci

* cuda : print total VRAM used
This commit is contained in:
slaren 2023-09-28 21:42:38 +02:00 committed by GitHub
parent 0512d66670
commit 16bc66d947
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 713 additions and 633 deletions

View file

@ -129,6 +129,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency();
}
} else if (arg == "-tb" || arg == "--threads-batch") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_threads_batch = std::stoi(argv[i]);
if (params.n_threads_batch <= 0) {
params.n_threads_batch = std::thread::hardware_concurrency();
}
} else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
invalid_param = true;
@ -451,12 +460,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.mul_mat_q = false;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--low-vram" || arg == "-lv") {
#ifdef GGML_USE_CUBLAS
params.low_vram = true;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mmap") {
params.use_mmap = false;
@ -630,7 +633,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" (can be specified more than once for multiple prompts).\n");
printf(" --color colorise output to distinguish prompt and user input from generations\n");
printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
printf(" -tb N, --threads-batch N\n");
printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
printf(" -p PROMPT, --prompt PROMPT\n");
printf(" prompt to start generation with (default: empty)\n");
printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
@ -645,7 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -f FNAME, --file FNAME\n");
printf(" prompt file to start generation.\n");
printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
@ -705,7 +710,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n");
#ifdef GGML_USE_CUBLAS
printf(" -nommq, --no-mul-mat-q\n");
printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
@ -726,6 +730,18 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf("\n");
}
std::string get_system_info(const gpt_params & params) {
std::ostringstream os;
os << "system_info: n_threads = " << params.n_threads;
if (params.n_threads_batch != -1) {
os << " (n_threads_batch = " << params.n_threads_batch << ")";
}
os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
return os.str();
}
std::string gpt_random_prompt(std::mt19937 & rng) {
const int r = rng() % 10;
switch (r) {
@ -749,40 +765,50 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
// Model utils
//
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
auto mparams = llama_model_default_params();
lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
if (params.n_gpu_layers != -1) {
lparams.n_gpu_layers = params.n_gpu_layers;
mparams.n_gpu_layers = params.n_gpu_layers;
}
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
lparams.low_vram = params.low_vram;
lparams.mul_mat_q = params.mul_mat_q;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.logits_all;
lparams.embedding = params.embedding;
lparams.rope_freq_base = params.rope_freq_base;
lparams.rope_freq_scale = params.rope_freq_scale;
mparams.main_gpu = params.main_gpu;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
return lparams;
return mparams;
}
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();
cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.mul_mat_q = params.mul_mat_q;
cparams.seed = params.seed;
cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.rope_freq_base = params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale;
return cparams;
}
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto lparams = llama_context_params_from_gpt_params(params);
auto mparams = llama_model_params_from_gpt_params(params);
llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return std::make_tuple(nullptr, nullptr);
}
llama_context * lctx = llama_new_context_with_model(model, lparams);
auto cparams = llama_context_params_from_gpt_params(params);
llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
@ -815,7 +841,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
LOG("warming up the model with an empty run\n");
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_tokens_rm(lctx, -1, -1);
llama_reset_timings(lctx);
}
@ -828,16 +854,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
//
std::vector<llama_token> llama_tokenize(
struct llama_context * ctx,
const struct llama_context * ctx,
const std::string & text,
bool add_bos) {
return llama_tokenize(llama_get_model(ctx), text, add_bos);
}
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
@ -847,10 +880,10 @@ std::vector<llama_token> llama_tokenize(
std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(ctx, token, result.data(), result.size());
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
@ -905,7 +938,7 @@ llama_token llama_sample_token(
std::vector<llama_token_data> & candidates,
int idx) {
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
@ -1191,7 +1224,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
#endif // NDEBUG
fprintf(stream, "model_desc: %s\n", model_desc);
fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(lctx));
fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx)));
#ifdef __OPTIMIZE__
fprintf(stream, "optimize: true\n");
@ -1258,7 +1291,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
}
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false");
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat);

View file

@ -36,6 +36,7 @@ int32_t get_num_physical_cores();
struct gpt_params {
uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
@ -95,7 +96,6 @@ struct gpt_params {
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
@ -126,6 +126,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
std::string get_system_info(const gpt_params & params);
std::string gpt_random_prompt(std::mt19937 & rng);
void process_escapes(std::string& input);
@ -135,6 +137,7 @@ void process_escapes(std::string& input);
//
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
//
@ -144,7 +147,12 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
// tokenizes a string into a vector of tokens
// should work similar to Python's `tokenizer.encode`
std::vector<llama_token> llama_tokenize(
struct llama_context * ctx,
const struct llama_context * ctx,
const std::string & text,
bool add_bos);
std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos);

View file

@ -858,7 +858,7 @@ size_t tokenize_file(
out_tokens.resize(buf.size() + n_max_tokens_overhead);
int n_tokens = llama_tokenize(
lctx,
llama_get_model(lctx),
buf.data(),
(int) buf.size(),
out_tokens.data(),
@ -867,7 +867,7 @@ size_t tokenize_file(
if (n_tokens < 0) {
out_tokens.resize(-n_tokens);
n_tokens = llama_tokenize(
lctx,
llama_get_model(lctx),
buf.data(),
(int) buf.size(),
out_tokens.data(),
@ -920,7 +920,7 @@ size_t tokenize_file(
size_t found_max_sample_size = 0;
size_t max_token_text_size = 0;
int n_vocab = llama_n_vocab(lctx);
int n_vocab = llama_n_vocab(llama_get_model(lctx));
for (llama_token token=0; token < n_vocab; ++token) {
max_token_text_size = std::max(
max_token_text_size,
@ -961,7 +961,7 @@ size_t tokenize_file(
// tokenize the sample
tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
int n_tokens = llama_tokenize(lctx,
int n_tokens = llama_tokenize(llama_get_model(lctx),
buf_sample.data(),
(int) buf_sample.size(),
tok_sample.data(),
@ -969,7 +969,7 @@ size_t tokenize_file(
false);
if (n_tokens < 0) {
tok_sample.resize(-n_tokens);
n_tokens = llama_tokenize(lctx,
n_tokens = llama_tokenize(llama_get_model(lctx),
buf_sample.data(),
(int) buf_sample.size(),
tok_sample.data(),

View file

@ -40,20 +40,35 @@ int main(int argc, char ** argv) {
llama_backend_init(params.numa);
llama_context_params ctx_params = llama_context_default_params();
// initialize the model
ctx_params.seed = 1234;
ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
ctx_params.n_batch = std::max(n_len, n_parallel);
// ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
llama_model_params model_params = llama_model_default_params();
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
// model_params.n_gpu_layers = 99; // offload all layers to the GPU
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
// tokenize the prompt
std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(model, params.prompt, true);
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
// initialize the context
llama_context_params ctx_params = llama_context_default_params();
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel);
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
@ -61,13 +76,7 @@ int main(int argc, char ** argv) {
return 1;
}
// tokenize the prompt
std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
const int n_ctx = llama_n_ctx(ctx);
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
@ -106,7 +115,7 @@ int main(int argc, char ** argv) {
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
if (llama_decode(ctx, batch, params.n_threads) != 0) {
if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
@ -146,7 +155,7 @@ int main(int argc, char ** argv) {
continue;
}
auto n_vocab = llama_n_vocab(ctx);
auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
std::vector<llama_token_data> candidates;
@ -210,7 +219,7 @@ int main(int argc, char ** argv) {
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch, params.n_threads)) {
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}

View file

@ -160,7 +160,7 @@ int main(int argc, char ** argv)
int n_past = 0;
if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads))
if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0)))
{
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
return 1;
@ -170,7 +170,7 @@ int main(int argc, char ** argv)
beam_search_callback_data callback_data{ctx, {}};
size_t const beam_width = static_cast<size_t>(params.n_beams);
int const n_predict = 256;
llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);
llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict);
std::cout << "\n\n";
for (llama_token const token_id : callback_data.response) {

View file

@ -48,8 +48,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) {
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
struct MyModel * ret = new MyModel();
ret->ctx = ctx;
@ -71,7 +70,7 @@ bool eval_float(void * model, float * input, int N){
MyModel * mymodel = (MyModel*)model;
llama_context * ctx = mymodel->ctx;
gpt_params params = mymodel->params;
int n_emb = llama_n_embd(ctx);
int n_emb = llama_n_embd(llama_get_model(ctx));
int n_past = mymodel->n_past;
int n_batch = N; // params.n_batch;
@ -81,7 +80,7 @@ bool eval_float(void * model, float * input, int N){
n_eval = n_batch;
}
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
if (llama_decode(ctx, batch, params.n_threads)) {
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
@ -102,7 +101,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
@ -133,7 +132,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
@ -149,7 +148,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
llama_token id = 0;
{
auto logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {

View file

@ -8,7 +8,7 @@ int main(int argc, char** argv) {
auto mymodel = create_mymodel(argc, argv);
int N = 10;
int max_tgt_len = 500;
int n_embd = llama_n_embd(mymodel->ctx);
int n_embd = llama_n_embd(llama_get_model(mymodel->ctx));
// add random float embd to test evaluation
float * data = new float[N*n_embd];

View file

@ -42,17 +42,18 @@ int main(int argc, char ** argv) {
return 1;
}
const int n_ctx_train = llama_n_ctx_train(ctx);
if (params.n_ctx > n_ctx_train) {
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
if (n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
__func__, n_ctx_train, n_ctx);
}
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
int n_past = 0;
@ -70,15 +71,15 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
}
if (embd_inp.size() > (size_t)params.n_ctx) {
if (embd_inp.size() > (size_t)n_ctx) {
fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
__func__, embd_inp.size(), params.n_ctx);
__func__, embd_inp.size(), n_ctx);
return 1;
}
while (!embd_inp.empty()) {
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
@ -86,8 +87,8 @@ int main(int argc, char ** argv) {
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
}
const int n_embd = llama_n_embd(ctx);
const auto embeddings = llama_get_embeddings(ctx);
const int n_embd = llama_n_embd(model);
const auto * embeddings = llama_get_embeddings(ctx);
for (int i = 0; i < n_embd; i++) {
printf("%f ", embeddings[i]);

View file

@ -304,7 +304,7 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
gguf_free(mctx);
}
hparams.n_vocab = llama_model_n_vocab(input);
hparams.n_vocab = llama_n_vocab(input);
hparams.n_ctx = n_ctx;
// get tensors from llama_model (possibly mmapped)
@ -1540,12 +1540,14 @@ int main(int argc, char ** argv) {
printf("%s: seed: %u\n", __func__, params.common.seed);
srand(params.common.seed);
struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = false;
struct llama_model_params llama_mparams = llama_model_default_params();
llama_mparams.vocab_only = false;
printf("%s: model base = '%s'\n", __func__, params.fn_model_base);
struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_params);
struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_mparams);
struct llama_context_params llama_cparams = llama_context_default_params();
struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_cparams);
struct my_llama_model model;
init_model(lmodel, &model, params.fn_model_base, params.common.n_ctx);

View file

@ -132,7 +132,6 @@ struct cmd_params {
std::vector<int> n_gpu_layers;
std::vector<int> main_gpu;
std::vector<bool> mul_mat_q;
std::vector<bool> low_vram;
std::vector<std::array<float, LLAMA_MAX_DEVICES>> tensor_split;
int reps;
bool verbose;
@ -149,7 +148,6 @@ static const cmd_params cmd_params_defaults = {
/* n_gpu_layers */ {99},
/* main_gpu */ {0},
/* mul_mat_q */ {true},
/* low_vram */ {false},
/* tensor_split */ {{}},
/* reps */ 5,
/* verbose */ false,
@ -167,9 +165,8 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
printf(" --memory-f32 <0|1> (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str());
printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
printf(" -ngl N, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
printf(" -mg i, --main-gpu <n> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -lv, --low-vram <0|1> (default: %s)\n", join(cmd_params_defaults.low_vram, ",").c_str());
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -mmq, --mul-mat-q <0|1> (default: %s)\n", join(cmd_params_defaults.mul_mat_q, ",").c_str());
printf(" -ts, --tensor_split <ts0/ts1/..> \n");
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
@ -255,13 +252,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.main_gpu = split<int>(argv[i], split_delim);
} else if (arg == "-lv" || arg == "--low-vram") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = split<bool>(argv[i], split_delim);
params.low_vram.insert(params.low_vram.end(), p.begin(), p.end());
} else if (arg == "-mmq" || arg == "--mul-mat-q") {
if (++i >= argc) {
invalid_param = true;
@ -336,7 +326,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.mul_mat_q.empty()) { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
if (params.low_vram.empty()) { params.low_vram = cmd_params_defaults.low_vram; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; }
@ -353,21 +342,34 @@ struct cmd_params_instance {
int n_gpu_layers;
int main_gpu;
bool mul_mat_q;
bool low_vram;
std::array<float, LLAMA_MAX_DEVICES> tensor_split;
llama_context_params to_llama_params() const {
llama_context_params lparams = llama_context_default_params();
lparams.n_ctx = n_prompt + n_gen;
lparams.n_batch = n_batch;
lparams.f16_kv = !f32_kv;
lparams.n_gpu_layers = n_gpu_layers;
lparams.main_gpu = main_gpu;
lparams.mul_mat_q = mul_mat_q;
lparams.low_vram = low_vram;
lparams.tensor_split = tensor_split.data();
llama_model_params to_llama_mparams() const {
llama_model_params mparams = llama_model_default_params();
return lparams;
mparams.n_gpu_layers = n_gpu_layers;
mparams.main_gpu = main_gpu;
mparams.tensor_split = tensor_split.data();
return mparams;
}
bool equal_mparams(const cmd_params_instance & other) const {
return model == other.model &&
n_gpu_layers == other.n_gpu_layers &&
main_gpu == other.main_gpu &&
tensor_split == other.tensor_split;
}
llama_context_params to_llama_cparams() const {
llama_context_params cparams = llama_context_default_params();
cparams.n_ctx = n_prompt + n_gen;
cparams.n_batch = n_batch;
cparams.f16_kv = !f32_kv;
cparams.mul_mat_q = mul_mat_q;
return cparams;
}
};
@ -375,13 +377,12 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
std::vector<cmd_params_instance> instances;
for (const auto & m : params.model)
for (const auto & nb : params.n_batch)
for (const auto & fk : params.f32_kv)
for (const auto & nl : params.n_gpu_layers)
for (const auto & mg : params.main_gpu)
for (const auto & mmq : params.mul_mat_q)
for (const auto & lv : params.low_vram)
for (const auto & ts : params.tensor_split)
for (const auto & nb : params.n_batch)
for (const auto & fk : params.f32_kv)
for (const auto & mmq : params.mul_mat_q)
for (const auto & nt : params.n_threads) {
cmd_params_instance instance = {
/* .model = */ m,
@ -393,7 +394,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
/* .n_gpu_layers = */ nl,
/* .main_gpu = */ mg,
/* .mul_mat_q = */ mmq,
/* .low_vram = */ lv,
/* .tensor_split = */ ts,
};
instances.push_back(instance);
@ -404,6 +404,56 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_params & params) {
std::vector<cmd_params_instance> instances;
#if 1
// this ordering minimizes the number of times that each model needs to be reloaded
for (const auto & m : params.model)
for (const auto & nl : params.n_gpu_layers)
for (const auto & mg : params.main_gpu)
for (const auto & ts : params.tensor_split)
for (const auto & nb : params.n_batch)
for (const auto & fk : params.f32_kv)
for (const auto & mmq : params.mul_mat_q)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
continue;
}
cmd_params_instance instance = {
/* .model = */ m,
/* .n_prompt = */ n_prompt,
/* .n_gen = */ 0,
/* .n_batch = */ nb,
/* .f32_kv = */ fk,
/* .n_threads = */ nt,
/* .n_gpu_layers = */ nl,
/* .main_gpu = */ mg,
/* .mul_mat_q = */ mmq,
/* .tensor_split = */ ts,
};
instances.push_back(instance);
}
for (const auto & n_gen : params.n_gen) {
if (n_gen == 0) {
continue;
}
cmd_params_instance instance = {
/* .model = */ m,
/* .n_prompt = */ 0,
/* .n_gen = */ n_gen,
/* .n_batch = */ nb,
/* .f32_kv = */ fk,
/* .n_threads = */ nt,
/* .n_gpu_layers = */ nl,
/* .main_gpu = */ mg,
/* .mul_mat_q = */ mmq,
/* .tensor_split = */ ts,
};
instances.push_back(instance);
}
}
#else
// this ordering separates the prompt and generation tests
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
continue;
@ -419,6 +469,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
auto instances_gen = get_cmd_params_instances_int(params, n_gen, 0);
instances.insert(instances.end(), instances_gen.begin(), instances_gen.end());
}
#endif
return instances;
}
@ -443,7 +494,6 @@ struct test {
int n_gpu_layers;
int main_gpu;
bool mul_mat_q;
bool low_vram;
std::array<float, LLAMA_MAX_DEVICES> tensor_split;
int n_prompt;
int n_gen;
@ -463,7 +513,6 @@ struct test {
n_gpu_layers = inst.n_gpu_layers;
main_gpu = inst.main_gpu;
mul_mat_q = inst.mul_mat_q;
low_vram = inst.low_vram;
tensor_split = inst.tensor_split;
n_prompt = inst.n_prompt;
n_gen = inst.n_gen;
@ -524,7 +573,7 @@ struct test {
"cpu_info", "gpu_info",
"model_filename", "model_type", "model_size", "model_n_params",
"n_batch", "n_threads", "f16_kv",
"n_gpu_layers", "main_gpu", "mul_mat_q", "low_vram", "tensor_split",
"n_gpu_layers", "main_gpu", "mul_mat_q", "tensor_split",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
"avg_ts", "stddev_ts"
@ -543,7 +592,7 @@ struct test {
return INT;
}
if (field == "cuda" || field == "opencl" || field == "metal" || field == "gpu_blas" || field == "blas" ||
field == "f16_kv" || field == "mul_mat_q" || field == "low_vram") {
field == "f16_kv" || field == "mul_mat_q") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@ -574,7 +623,7 @@ struct test {
cpu_info, gpu_info,
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv),
std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), std::to_string(low_vram), tensor_split_str,
std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), tensor_split_str,
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
std::to_string(avg_ts()), std::to_string(stdev_ts())
@ -766,9 +815,6 @@ struct markdown_printer : public printer {
if (params.mul_mat_q.size() > 1 || params.mul_mat_q != cmd_params_defaults.mul_mat_q) {
fields.push_back("mul_mat_q");
}
if (params.low_vram.size() > 1 || params.low_vram != cmd_params_defaults.low_vram) {
fields.push_back("low_vram");
}
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.push_back("tensor_split");
}
@ -889,17 +935,23 @@ struct sql_printer : public printer {
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
std::vector<llama_token> tokens(n_batch, llama_token_bos(ctx));
int n_processed = 0;
llama_set_n_threads(ctx, n_threads, n_threads);
while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch);
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads);
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
n_processed += n_tokens;
}
}
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_token token = llama_token_bos(ctx);
llama_set_n_threads(ctx, n_threads, n_threads);
for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads);
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
}
}
@ -958,17 +1010,25 @@ int main(int argc, char ** argv) {
std::vector<cmd_params_instance> params_instances = get_cmd_params_instances(params);
for (const auto & inst : params_instances) {
// TODO: keep the model between tests when possible
llama_context_params lparams = inst.to_llama_params();
llama_model * lmodel = nullptr;
const cmd_params_instance * prev_inst = nullptr;
llama_model * lmodel = llama_load_model_from_file(inst.model.c_str(), lparams);
if (lmodel == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str());
return 1;
for (const auto & inst : params_instances) {
// keep the same model between tests when possible
if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) {
if (lmodel) {
llama_free_model(lmodel);
}
lmodel = llama_load_model_from_file(inst.model.c_str(), inst.to_llama_mparams());
if (lmodel == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str());
return 1;
}
prev_inst = &inst;
}
llama_context * ctx = llama_new_context_with_model(lmodel, lparams);
llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams());
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str());
llama_free_model(lmodel);
@ -1006,9 +1066,10 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);
llama_free_model(lmodel);
}
llama_free_model(lmodel);
p->print_footer();
llama_backend_free();

View file

@ -262,7 +262,8 @@ These options help improve the performance and memory usage of the LLaMA models.
### Number of Threads
- `-t N, --threads N`: Set the number of threads to use during computation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance.
- `-t N, --threads N`: Set the number of threads to use during generation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). Using the correct number of threads can greatly improve performance.
- `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. In some systems, it is beneficial to use a higher number of threads during batch processing than during generation. If not specified, the number of threads used for batch processing will be the same as the number of threads used for generation.
### Mlock
@ -305,6 +306,5 @@ These options provide extra functionality and customization when running the LLa
- `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.

View file

@ -140,12 +140,17 @@ int main(int argc, char ** argv) {
return 0;
}
if (params.rope_freq_base != 10000.0) {
LOG_TEE("%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
if (params.n_ctx != 0 && params.n_ctx < 8) {
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
}
if (params.rope_freq_scale != 1.0) {
LOG_TEE("%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
if (params.rope_freq_base != 0.0) {
LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
}
if (params.rope_freq_scale != 0.0) {
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
}
LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
@ -184,20 +189,19 @@ int main(int argc, char ** argv) {
return 1;
}
const int n_ctx_train = llama_n_ctx_train(ctx);
if (params.n_ctx > n_ctx_train) {
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
LOG("n_ctx: %d\n", n_ctx);
if (n_ctx > n_ctx_train) {
LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
} else if (params.n_ctx < 8) {
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
__func__, n_ctx_train, n_ctx);
}
// print system information
{
LOG_TEE("\n");
LOG_TEE("system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
LOG_TEE("%s\n", get_system_info(params).c_str());
}
std::string path_session = params.path_prompt_cache;
@ -211,7 +215,7 @@ int main(int argc, char ** argv) {
if (fp != NULL) {
std::fclose(fp);
session_tokens.resize(params.n_ctx);
session_tokens.resize(n_ctx);
size_t n_token_count_out = 0;
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
@ -226,7 +230,7 @@ int main(int argc, char ** argv) {
}
}
const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
LOG("add_bos: %d\n", add_bos);
std::vector<llama_token> embd_inp;
@ -267,9 +271,6 @@ int main(int argc, char ** argv) {
LOG("guidance_offset: %s", log_tostr(guidance_offset));
}
const int n_ctx = llama_n_ctx(ctx);
LOG("n_ctx: %d\n", n_ctx);
if ((int) embd_inp.size() > n_ctx - 4) {
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
return 1;
@ -466,7 +467,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
const int n_vocab = llama_n_vocab(ctx);
const int n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@ -576,7 +577,7 @@ int main(int argc, char ** argv) {
for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch);
if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) {
if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
@ -593,7 +594,7 @@ int main(int argc, char ** argv) {
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}

View file

@ -108,7 +108,7 @@ int main(int argc, char ** argv) {
fflush(stderr);
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(ctx);
const int n_vocab = llama_n_vocab(model);
std::vector<client> clients(n_clients);
for (size_t i = 0; i < clients.size(); ++i) {
@ -153,7 +153,7 @@ int main(int argc, char ** argv) {
batch.logits[i] = false;
}
if (llama_decode(ctx, batch, params.n_threads) != 0) {
if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
@ -272,7 +272,7 @@ int main(int argc, char ** argv) {
0, 0, 0, // unused
};
const int ret = llama_decode(ctx, batch_view, params.n_threads);
const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size

View file

@ -150,16 +150,18 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
if (int(tokens.size()) < 2*params.n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
params.n_ctx);
const int n_ctx = llama_n_ctx(ctx);
if (int(tokens.size()) < 2*n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
@ -175,20 +177,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {tokens, -1, logit_history, prob_history};
}
const int calc_chunk = params.n_ctx;
const int calc_chunk = n_ctx;
fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk);
if (int(tokens.size()) <= calc_chunk) {
fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
tokens.size(), params.n_ctx, params.ppl_stride);
tokens.size(), n_ctx, params.ppl_stride);
return {tokens, -1, logit_history, prob_history};
}
const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
int count = 0;
@ -215,7 +217,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const int batch_size = std::min(end - batch_start, n_batch);
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
//fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
@ -250,7 +252,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
}
//fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) {
for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) {
// Calculate probability of next token, given the previous ones.
const std::vector<float> tok_logits(
@ -287,8 +289,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
const bool add_bos = is_spm;
const int n_ctx = llama_n_ctx(ctx);
auto tim1 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@ -298,9 +301,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
if (int(tokens.size()) < 2*params.n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*params.n_ctx,
params.n_ctx);
if (int(tokens.size()) < 2*n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
n_ctx);
fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
return {std::move(tokens), 0., {}, {}};
}
@ -311,10 +314,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<float> prob_history;
prob_history.resize(tokens.size());
const int n_chunk_max = tokens.size() / params.n_ctx;
const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch;
int count = 0;
@ -326,10 +329,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
for (int i = 0; i < n_chunk; ++i) {
const int start = i * params.n_ctx;
const int end = start + params.n_ctx;
const int start = i * n_ctx;
const int end = start + n_ctx;
const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
@ -350,7 +353,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
tokens[batch_start] = llama_token_bos(ctx);
}
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
@ -358,7 +361,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
const auto batch_logits = llama_get_logits(ctx);
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
@ -387,10 +390,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// Example, we have a context window of 512, we will compute perplexity for each of the
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
const int first = params.n_ctx/2;
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
const int first = n_ctx/2;
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += params.n_ctx - first - 1;
count += n_ctx - first - 1;
// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
@ -399,7 +402,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
double av = nll/count;
double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1));
printf("%8d %.4lf %4lf %4lf\n", i*params.n_ctx, std::exp(nll / count), av, av2);
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
}
fflush(stdout);
}
@ -420,7 +423,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
static std::vector<float> hellaswag_evaluate_tokens(
llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread
llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab
) {
std::vector<float> result;
result.reserve(tokens.size() * n_vocab);
@ -428,7 +431,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
size_t n_tokens = tokens.size() - i_chunk * n_batch;
n_tokens = std::min(n_tokens, size_t(n_batch));
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) {
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {};
}
@ -475,7 +478,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
size_t hs_task_count = prompt_lines.size()/6;
fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM;
fprintf(stderr, "================================= is_spm = %d\n", is_spm);
// This is needed as usual for LLaMA models
@ -530,7 +533,8 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
printf("\ntask\tacc_norm\n");
double acc = 0.0f;
const int n_vocab = llama_n_vocab(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
std::vector<std::vector<int>> ending_tokens(4);
@ -558,7 +562,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
auto query_size = query_embd.size();
// Stop if query wont fit the ctx window
if (query_size > (size_t)params.n_ctx) {
if (query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
@ -571,7 +575,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
// clear the KV cache
llama_kv_cache_tokens_rm(ctx, -1, -1);
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
@ -608,7 +612,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
query_size = query_embd.size();
// Stop if query wont fit the ctx window
if (context_size + query_size > (size_t)params.n_ctx) {
if (context_size + query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
@ -620,7 +624,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
//}
// Evaluate the query
logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
@ -716,7 +720,7 @@ int main(int argc, char ** argv) {
return 1;
}
const int n_ctx_train = llama_n_ctx_train(ctx);
const int n_ctx_train = llama_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, params.n_ctx);
@ -725,8 +729,7 @@ int main(int argc, char ** argv) {
// print system information
{
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
fprintf(stderr, "%s\n", get_system_info(params).c_str());
}
struct results_perplexity results;

View file

@ -309,21 +309,22 @@ int main(int argc, char ** argv) {
llama_context * ctx;
{
auto lparams = llama_context_default_params();
auto mparams = llama_model_default_params();
mparams.use_mlock = false;
lparams.n_ctx = 256;
lparams.seed = 1;
lparams.f16_kv = false;
lparams.use_mlock = false;
model = llama_load_model_from_file(params.model.c_str(), lparams);
model = llama_load_model_from_file(params.model.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
return 1;
}
ctx = llama_new_context_with_model(model, lparams);
auto cparams = llama_context_default_params();
cparams.n_ctx = 256;
cparams.seed = 1;
cparams.f16_kv = false;
ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());

View file

@ -23,23 +23,17 @@ int main(int argc, char ** argv) {
params.n_predict = 16;
}
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
auto n_past = 0;
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init
auto * model = llama_load_model_from_file(params.model.c_str(), lparams);
llama_model * model;
llama_context * ctx;
std::tie(model, ctx) = llama_init_from_gpt_params( params );
if (model == nullptr) {
return 1;
}
auto * ctx = llama_new_context_with_model(model, lparams);
if (ctx == nullptr) {
llama_free_model(model);
return 1;
@ -54,7 +48,7 @@ int main(int argc, char ** argv) {
}
// evaluate prompt
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads);
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0));
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
n_past += n_prompt_tokens;
@ -79,7 +73,7 @@ int main(int argc, char ** argv) {
for (auto i = 0; i < params.n_predict; i++) {
auto * logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
auto n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@ -91,7 +85,7 @@ int main(int argc, char ** argv) {
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str.c_str());
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx);
llama_free_model(model);
@ -106,7 +100,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
// make new context
auto * ctx2 = llama_new_context_with_model(model, lparams);
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
// Load state (rng, logits, embedding and kv_cache) from file
{
@ -139,7 +133,7 @@ int main(int argc, char ** argv) {
// second run
for (auto i = 0; i < params.n_predict; i++) {
auto * logits = llama_get_logits(ctx2);
auto n_vocab = llama_n_vocab(ctx2);
auto n_vocab = llama_n_vocab(model);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@ -151,7 +145,7 @@ int main(int argc, char ** argv) {
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str.c_str());
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2);
llama_free_model(model);

View file

@ -4,14 +4,14 @@ This example demonstrates a simple HTTP API server and a simple web front end to
Command line options:
- `--threads N`, `-t N`: Set the number of threads to use during computation.
- `--threads N`, `-t N`: Set the number of threads to use during generation.
- `-tb N, --threads-batch N`: Set the number of threads to use during batch and prompt processing. If not specified, the number of threads will be set to the number of threads used for generation.
- `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`).
- `-m ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses.
- `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096.
- `-ngl N`, `--n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
- `-b N`, `--batch-size N`: Set the batch size for prompt processing. Default: `512`.
- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. Not recommended.
- `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped.

View file

@ -200,6 +200,7 @@ struct llama_server_context
llama_model *model = nullptr;
llama_context *ctx = nullptr;
gpt_params params;
int n_ctx;
grammar_parser::parse_state parsed_grammar;
llama_grammar *grammar = nullptr;
@ -239,7 +240,7 @@ struct llama_server_context
num_prompt_tokens = 0;
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
generated_text.reserve(n_ctx);
generated_token_probs.clear();
truncated = false;
stopped_eos = false;
@ -265,8 +266,8 @@ struct llama_server_context
LOG_ERROR("unable to load model", {{"model", params_.model}});
return false;
}
last_n_tokens.resize(params.n_ctx);
n_ctx = llama_n_ctx(ctx);
last_n_tokens.resize(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
return true;
}
@ -351,19 +352,19 @@ struct llama_server_context
{
params.n_keep = (int)num_prompt_tokens;
}
params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
params.n_keep = std::min(n_ctx - 4, params.n_keep);
// if input prompt is too big, truncate like normal
if (num_prompt_tokens >= (size_t)params.n_ctx)
if (num_prompt_tokens >= (size_t)n_ctx)
{
const int n_left = (params.n_ctx - params.n_keep) / 2;
const int n_left = (n_ctx - params.n_keep) / 2;
std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx},
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
@ -413,7 +414,7 @@ struct llama_server_context
completion_token_output result;
result.tok = -1;
if (embd.size() >= (size_t)params.n_ctx)
if (embd.size() >= (size_t)n_ctx)
{
// Shift context
@ -433,7 +434,7 @@ struct llama_server_context
truncated = true;
LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx},
{"n_ctx", n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
});
@ -447,12 +448,11 @@ struct llama_server_context
n_eval = params.n_batch;
}
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0)))
{
LOG_ERROR("failed to eval", {
{"n_eval", n_eval},
{"n_past", n_past},
{"n_threads", params.n_threads},
{"embd", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
});
has_next_token = false;
@ -470,11 +470,11 @@ struct llama_server_context
// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(model) : params.top_k;
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
@ -486,7 +486,7 @@ struct llama_server_context
{
auto *logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
auto n_vocab = llama_n_vocab(model);
// Apply params.logit_bias map
for (const auto &it : params.logit_bias)
@ -505,7 +505,7 @@ struct llama_server_context
// Apply penalties
float nl_logit = logits[llama_token_nl(ctx)];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
@ -690,7 +690,7 @@ struct llama_server_context
std::vector<float> getEmbedding()
{
static const int n_embd = llama_n_embd(ctx);
static const int n_embd = llama_n_embd(model);
if (!params.embedding)
{
LOG_WARNING("embedding disabled", {
@ -734,7 +734,6 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n");
printf(" -nommq, --no-mul-mat-q\n");
printf(" use cuBLAS instead of custom mul_mat_q CUDA kernels.\n");
printf(" Not recommended since this is both slower and uses more VRAM.\n");
@ -918,14 +917,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
#else
LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {});
#endif // GGML_USE_CUBLAS
}
else if (arg == "--low-vram" || arg == "-lv")
{
#ifdef GGML_USE_CUBLAS
params.low_vram = true;
#else
LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n", {});
#endif // GGML_USE_CUBLAS
}
else if (arg == "--no-mul-mat-q" || arg == "-nommq")
@ -1031,7 +1022,7 @@ static json format_generation_settings(llama_server_context &llama)
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json{
{"n_ctx", llama.params.n_ctx},
{"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias},
{"seed", llama.params.seed},
{"temp", llama.params.temp},
@ -1191,7 +1182,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
const auto &logit_bias = body.find("logit_bias");
if (logit_bias != body.end() && logit_bias->is_array())
{
const int n_vocab = llama_n_vocab(llama.ctx);
const int n_vocab = llama_n_vocab(llama.model);
for (const auto &el : *logit_bias)
{
if (el.is_array() && el.size() == 2 && el[0].is_number_integer())
@ -1324,6 +1315,7 @@ int main(int argc, char **argv)
{"commit", BUILD_COMMIT}});
LOG_INFO("system info", {
{"n_threads", params.n_threads},
{"n_threads_batch", params.n_threads_batch},
{"total_threads", std::thread::hardware_concurrency()},
{"system_info", llama_print_system_info()},
});
@ -1387,7 +1379,7 @@ int main(int argc, char **argv)
if (llama.params.n_beams) {
// Fill llama.generated_token_probs vector with final beam.
llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
llama.n_past, llama.n_remain, llama.params.n_threads);
llama.n_past, llama.n_remain);
// Translate llama.generated_token_probs to llama.generated_text.
append_to_generated_text_from_generated_token_probs(llama);
} else {

View file

@ -33,18 +33,28 @@ int main(int argc, char ** argv) {
llama_backend_init(params.numa);
llama_context_params ctx_params = llama_context_default_params();
// initialize the model
ctx_params.seed = 1234;
ctx_params.n_ctx = 2048;
llama_model_params model_params = llama_model_default_params();
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
// model_params.n_gpu_layers = 99; // offload all layers to the GPU
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
// initialize the context
llama_context_params ctx_params = llama_context_default_params();
ctx_params.seed = 1234;
ctx_params.n_ctx = 2048;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (ctx == NULL) {
@ -97,7 +107,7 @@ int main(int argc, char ** argv) {
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
if (llama_decode(ctx, batch, params.n_threads) != 0) {
if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}
@ -112,7 +122,7 @@ int main(int argc, char ** argv) {
while (n_cur <= n_len) {
// sample the next token
{
auto n_vocab = llama_n_vocab(ctx);
auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
std::vector<llama_token_data> candidates;
@ -154,7 +164,7 @@ int main(int argc, char ** argv) {
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch, params.n_threads)) {
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}

View file

@ -70,16 +70,16 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us();
// eval the prompt with both models
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0), params.n_threads);
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0), params.n_threads);
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0), params.n_threads);
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));
const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab
const int n_ctx = llama_n_ctx(ctx_tgt);
const int n_vocab = llama_n_vocab(ctx_tgt);
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
const int n_vocab = llama_n_vocab(model_tgt);
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
// how many tokens to draft each time
int n_draft = params.n_draft;
@ -173,7 +173,7 @@ int main(int argc, char ** argv) {
}
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
++n_past_dft;
// heuristic for n_draft
@ -258,7 +258,7 @@ int main(int argc, char ** argv) {
// evaluate the drafted token on the draft model
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
++n_past_cur;
if (grammar_dft != NULL) {
@ -268,7 +268,7 @@ int main(int argc, char ** argv) {
// evaluate the target model on the drafted tokens
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
++n_past_tgt;
// the first token is always proposed by the traget model before the speculation loop

View file

@ -976,14 +976,16 @@ int main(int argc, char ** argv) {
printf("%s: seed: %u\n", __func__, params.common.seed);
srand(params.common.seed);
struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = true;
struct llama_model_params mparams = llama_model_default_params();
mparams.vocab_only = true;
struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, llama_params);
struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
struct llama_context_params cparams = llama_context_default_params();
struct llama_model * lmodel = llama_load_model_from_file(params.fn_vocab_model, mparams);
struct llama_context * lctx = llama_new_context_with_model(lmodel, cparams);
struct my_llama_model model;
model.hparams.n_vocab = llama_n_vocab(lctx);
model.hparams.n_vocab = llama_n_vocab(lmodel);
model.hparams.n_ctx = params.common.n_ctx;
model.hparams.n_embd = params.n_embd;
model.hparams.n_head = params.n_head;

View file

@ -1,3 +1,4 @@
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
@ -467,7 +468,7 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
static bool g_mul_mat_q = true;
static void * g_scratch_buffer = nullptr;
static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
static size_t g_scratch_size = 0; // disabled by default
static size_t g_scratch_offset = 0;
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@ -6738,14 +6739,10 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
const int64_t ne1 = dst->ne[1];
// TODO: find the optimal values for these
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
return true;
}
return false;
return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
}
static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@ -6901,6 +6898,8 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ASSERT(false);
}
@ -7198,7 +7197,12 @@ void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
}
void ggml_cuda_set_scratch_size(const size_t scratch_size) {
g_scratch_size = scratch_size;
// this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously
// it still won't always work as expected, but it's better than nothing
if (scratch_size > g_scratch_size) {
ggml_cuda_free_scratch();
}
g_scratch_size = std::max(g_scratch_size, scratch_size);
}
void ggml_cuda_free_scratch() {

545
llama.cpp

File diff suppressed because it is too large Load diff

84
llama.h
View file

@ -149,32 +149,37 @@ extern "C" {
llama_seq_id all_seq_id; // used if seq_id == NULL
} llama_batch;
struct llama_context_params {
uint32_t seed; // RNG seed, -1 for random
int32_t n_ctx; // text context
int32_t n_batch; // prompt processing batch size
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
struct llama_model_params {
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency
float rope_freq_scale; // RoPE frequency scaling factor
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
// Keep the booleans together to avoid misalignment during copy-by-value.
bool low_vram; // if true, reduce VRAM usage at the cost of performance
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
};
struct llama_context_params {
uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context
uint32_t n_batch; // prompt processing batch size
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency
float rope_freq_scale; // RoPE frequency scaling factor
// Keep the booleans together to avoid misalignment during copy-by-value.
bool mul_mat_q; // if true, use experimental mul_mat_q kernels
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool embedding; // embedding mode only
};
@ -236,6 +241,7 @@ extern "C" {
};
// Helpers for getting default parameters
LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void);
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
@ -249,7 +255,7 @@ extern "C" {
LLAMA_API struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_context_params params);
struct llama_model_params params);
LLAMA_API void llama_free_model(struct llama_model * model);
@ -266,17 +272,15 @@ extern "C" {
LLAMA_API bool llama_mmap_supported (void);
LLAMA_API bool llama_mlock_supported(void);
LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int llama_model_n_embd (const struct llama_model * model);
LLAMA_API int llama_n_vocab (const struct llama_model * model);
LLAMA_API int llama_n_ctx_train(const struct llama_model * model);
LLAMA_API int llama_n_embd (const struct llama_model * model);
// Get a string describing the model type
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
@ -409,8 +413,7 @@ extern "C" {
struct llama_context * ctx,
llama_token * tokens,
int32_t n_tokens,
int n_past,
int n_threads),
int n_past),
"use llama_decode() instead");
// Same as llama_eval, but use float matrix input directly.
@ -419,8 +422,7 @@ extern "C" {
struct llama_context * ctx,
float * embd,
int32_t n_tokens,
int n_past,
int n_threads),
int n_past),
"use llama_decode() instead");
// Return batch for single sequence of tokens starting at pos_0
@ -452,8 +454,12 @@ extern "C" {
// < 0 - error
LLAMA_API int llama_decode(
struct llama_context * ctx,
struct llama_batch batch,
int n_threads);
struct llama_batch batch);
// Set the number of threads used for decoding
// n_threads is the number of threads used for generation (single token)
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
@ -494,14 +500,6 @@ extern "C" {
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
LLAMA_API int llama_tokenize(
struct llama_context * ctx,
const char * text,
int text_len,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
LLAMA_API int llama_tokenize_with_model(
const struct llama_model * model,
const char * text,
int text_len,
@ -514,12 +512,6 @@ extern "C" {
// Does not write null terminator to the buffer.
// User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
LLAMA_API int llama_token_to_piece(
const struct llama_context * ctx,
llama_token token,
char * buf,
int length);
LLAMA_API int llama_token_to_piece_with_model(
const struct llama_model * model,
llama_token token,
char * buf,
@ -700,15 +692,13 @@ extern "C" {
/// @param n_beams Number of beams to use.
/// @param n_past Number of tokens already evaluated.
/// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
/// @param n_threads Number of threads as passed to llama_eval().
LLAMA_API void llama_beam_search(
struct llama_context * ctx,
llama_beam_search_callback_fn_t callback,
void * callback_data,
size_t n_beams,
int n_past,
int n_predict,
int n_threads);
int n_predict);
// Performance information
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);

View file

@ -62,18 +62,20 @@ int main(int argc, char **argv) {
// load the vocab
{
auto lparams = llama_context_default_params();
auto mparams = llama_model_default_params();
lparams.vocab_only = true;
mparams.vocab_only = true;
model = llama_load_model_from_file(fname.c_str(), lparams);
model = llama_load_model_from_file(fname.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
ctx = llama_new_context_with_model(model, lparams);
auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@ -82,7 +84,7 @@ int main(int argc, char **argv) {
}
}
if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_BPE) {
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__);
llama_free_model(model);
llama_free(ctx);

View file

@ -64,18 +64,20 @@ int main(int argc, char **argv) {
// load the vocab
{
auto lparams = llama_context_default_params();
auto mparams = llama_model_default_params();
lparams.vocab_only = true;
mparams.vocab_only = true;
model = llama_load_model_from_file(fname.c_str(), lparams);
model = llama_load_model_from_file(fname.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
ctx = llama_new_context_with_model(model, lparams);
auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@ -84,7 +86,7 @@ int main(int argc, char **argv) {
}
}
if (llama_vocab_type(ctx) != LLAMA_VOCAB_TYPE_SPM) {
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_SPM) {
fprintf(stderr, "%s : error: vocab type is not SPM\n", __func__);
llama_free_model(model);
llama_free(ctx);

View file

@ -52,18 +52,20 @@ int main(int argc, char **argv) {
// load the vocab
{
auto lparams = llama_context_default_params();
auto mparams = llama_model_default_params();
lparams.vocab_only = true;
mparams.vocab_only = true;
model = llama_load_model_from_file(fname.c_str(), lparams);
model = llama_load_model_from_file(fname.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
ctx = llama_new_context_with_model(model, lparams);
auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@ -72,7 +74,7 @@ int main(int argc, char **argv) {
}
}
GGML_ASSERT(llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM);
GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
#ifdef _WIN32
// We need this for unicode console support
@ -80,7 +82,7 @@ int main(int argc, char **argv) {
atexit([]() { console::cleanup(); });
#endif
const int n_vocab = llama_n_vocab(ctx);
const int n_vocab = llama_n_vocab(model);
for (int i = 0; i < n_vocab; ++i) {
std::string str = llama_detokenize_spm(ctx, std::vector<int>(1, i));