diff --git a/examples/common.cpp b/examples/common.cpp index 730b28bde..2dc6654da 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -578,18 +578,18 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); - lparams.n_ctx = params.n_ctx; - lparams.n_batch = params.n_batch; - lparams.n_gpu_layers = params.n_gpu_layers; - lparams.main_gpu = params.main_gpu; - lparams.tensor_split = params.tensor_split; - lparams.low_vram = params.low_vram; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - lparams.logits_all = params.perplexity; - lparams.embedding = params.embedding; + lparams.n_ctx = params.n_ctx; + lparams.n_batch = params.n_batch; + lparams.n_gpu_layers = params.n_gpu_layers; + lparams.main_gpu = params.main_gpu; + lparams.tensor_split = params.tensor_split; + lparams.low_vram = params.low_vram; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; + lparams.logits_all = params.perplexity; + lparams.embedding = params.embedding; lparams.rope_freq_base = params.rope_freq_base; lparams.rope_freq_scale = params.rope_freq_scale; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 656382f81..4b4cd1de4 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -139,17 +139,14 @@ int main(int argc, char ** argv) { params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); } - // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters + // determine the maximum memory usage needed to do inference for the given n_batch and n_ctx parameters // uncomment the "used_mem" line in llama.cpp to see the results if (params.mem_test) { { - const std::vector tmp(params.n_batch, llama_token_bos()); - llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); - } + fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx); - { - const std::vector tmp = { 0, }; - llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads); + const std::vector tmp(params.n_batch, llama_token_bos()); + llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads); } llama_print_timings(ctx); diff --git a/llama.cpp b/llama.cpp index 0a381afd5..135aa9fef 100644 --- a/llama.cpp +++ b/llama.cpp @@ -98,18 +98,17 @@ static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * } // -// memory sizes +// memory sizes (calculated for n_batch == 512) // static const std::map & MEM_REQ_SCRATCH0(int n_ctx) { static std::map k_sizes = { - /* empirical scaling, still a guess */ - { MODEL_3B, ((size_t) n_ctx / 16ull + 128ull) * MB }, - { MODEL_7B, ((size_t) n_ctx / 16ull + 256ull) * MB }, - { MODEL_13B, ((size_t) n_ctx / 12ull + 256ull) * MB }, - { MODEL_30B, ((size_t) n_ctx / 10ull + 256ull) * MB }, - { MODEL_65B, ((size_t) n_ctx / 8ull + 512ull) * MB }, + { MODEL_3B, ((size_t) n_ctx / 16ull + 92ull) * MB }, + { MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB }, + { MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB }, + { MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB }, + { MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess }; return k_sizes; } @@ -117,38 +116,24 @@ static const std::map & MEM_REQ_SCRATCH0(int n_ctx) static const std::map & MEM_REQ_SCRATCH1() { static std::map k_sizes = { - { MODEL_3B, 256ull * MB }, - { MODEL_7B, 512ull * MB }, - { MODEL_13B, 512ull * MB }, - { MODEL_30B, 512ull * MB }, - { MODEL_65B, 1024ull * MB }, + { MODEL_3B, 128ull * MB }, + { MODEL_7B, 160ull * MB }, + { MODEL_13B, 192ull * MB }, + { MODEL_30B, 256ull * MB }, + { MODEL_65B, 384ull * MB }, // guess }; return k_sizes; } -// 2*n_embd*n_ctx*n_layer*sizeof(float16) -static const std::map & MEM_REQ_KV_SELF() +// used to store the compute graph tensors + non-scratch data +static const std::map & MEM_REQ_EVAL() { static std::map k_sizes = { - { MODEL_3B, 682ull * MB }, - { MODEL_7B, 1026ull * MB }, - { MODEL_13B, 1608ull * MB }, - { MODEL_30B, 3124ull * MB }, - { MODEL_65B, 5120ull * MB }, - }; - return k_sizes; -} - -// this is mostly needed for temporary mul_mat buffers to dequantize the data -// not actually needed if BLAS is disabled -static const std::map & MEM_REQ_EVAL(int n_ctx) -{ - static std::map k_sizes = { - { MODEL_3B, ((size_t) n_ctx / 256ull + 512ull) * MB }, - { MODEL_7B, ((size_t) n_ctx / 256ull + 768ull) * MB }, - { MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB }, - { MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB }, - { MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB }, + { MODEL_3B, 8ull * MB }, + { MODEL_7B, 10ull * MB }, + { MODEL_13B, 12ull * MB }, + { MODEL_30B, 16ull * MB }, + { MODEL_65B, 24ull * MB }, // guess }; return k_sizes; } @@ -199,6 +184,15 @@ struct llama_hparams { bool operator!=(const llama_hparams & other) const { return static_cast(memcmp(this, &other, sizeof(llama_hparams))); } + + size_t kv_size() const { + size_t result = 2ull; + result *= (size_t) n_embd; + result *= (size_t) n_ctx; + result *= (size_t) n_layer; + result *= sizeof(ggml_fp16_t); + return result; + } }; struct llama_layer { @@ -1069,7 +1063,7 @@ static void llama_model_load_internal( { model.buf.resize(ctx_size); if (use_mlock) { - model.mlock_buf.init(model.buf.addr); + model.mlock_buf.init (model.buf.addr); model.mlock_buf.grow_to(model.buf.size); } @@ -1186,11 +1180,11 @@ static void llama_model_load_internal( mmapped_size - vram_weights + // weights in VRAM not in memory MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) + MEM_REQ_SCRATCH1().at(model.type) + - MEM_REQ_EVAL(hparams.n_ctx).at(model.type); + MEM_REQ_EVAL().at(model.type); // this is the memory required by one llama_state const size_t mem_required_state = - scale*MEM_REQ_KV_SELF().at(model.type); + scale*hparams.kv_size(); fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); @@ -1231,7 +1225,7 @@ static void llama_model_load_internal( fprintf(stderr, "%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); } else { fprintf(stderr, "%s: offloading v cache to GPU\n", __func__); - vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2; + vram_kv_cache += hparams.kv_size() / 2; } } if (n_gpu_layers > (int) hparams.n_layer + 2) { @@ -1239,7 +1233,7 @@ static void llama_model_load_internal( fprintf(stderr, "%s: cannot offload k cache to GPU due to low VRAM option\n", __func__); } else { fprintf(stderr, "%s: offloading k cache to GPU\n", __func__); - vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2; + vram_kv_cache += hparams.kv_size() / 2; } } #elif defined(GGML_USE_CLBLAST) @@ -1739,10 +1733,12 @@ static bool llama_eval_internal( } #if 0 - printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__, + printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, lctx.get_buf_max_mem(0)/1024.0/1024.0, - lctx.get_buf_max_mem(1)/1024.0/1024.0); + lctx.get_buf_max_mem(1)/1024.0/1024.0, + lctx.work_buffer.size()/1024.0/1024.0, + n_past, N); #endif ggml_free(ctx0); @@ -2448,8 +2444,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; - case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break; - case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; + case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; #ifdef GGML_USE_K_QUANTS // K-quants @@ -2533,16 +2529,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } else { new_type = quantized_type; #ifdef GGML_USE_K_QUANTS - bool convert_incompatible_tensor = false; - if (quantized_type == GGML_TYPE_Q2_K || quantized_type == GGML_TYPE_Q3_K || quantized_type == GGML_TYPE_Q4_K || - quantized_type == GGML_TYPE_Q5_K || quantized_type == GGML_TYPE_Q6_K) { - int nx = tensor.ne.at(0); - int ny = tensor.ne.at(1); - if (nx % QK_K != 0 || ny % QK_K != 0) { - fprintf(stderr, "\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K); - convert_incompatible_tensor = true; - } - } if (tensor.name == "output.weight") { int nx = tensor.ne.at(0); int ny = tensor.ne.at(1); @@ -2568,6 +2554,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; } + bool convert_incompatible_tensor = false; + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || + new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) { + int nx = tensor.ne.at(0); + int ny = tensor.ne.at(1); + if (nx % QK_K != 0 || ny % QK_K != 0) { + fprintf(stderr, "\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K); + convert_incompatible_tensor = true; + } + } if (convert_incompatible_tensor) { if (tensor.name == "output.weight") { new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing. @@ -2594,7 +2590,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s f32_data = (float *) f32_conv_buf.addr; } - printf("quantizing .. "); + printf("quantizing to %s .. ", ggml_type_name(new_type)); fflush(stdout); work.resize(nelements * 4); // upper bound on size @@ -2775,7 +2771,7 @@ struct llama_context * llama_new_context_with_model( ctx->embedding.resize(hparams.n_embd); } - ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type)); + ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type)); ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type)); ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));