llama : per-layer KV cache + quantum K cache (#4309)

* per-layer KV

* remove unnecessary copies

* less code duplication, offload k and v separately

* llama : offload KV cache per-layer

* llama : offload K shift tensors

* llama : offload for rest of the model arches

* llama : enable offload debug temporarily

* llama : keep the KV related layers on the device

* llama : remove mirrors, perform Device -> Host when partial offload

* common : add command-line arg to disable KV cache offloading

* llama : update session save/load

* llama : support quantum K cache (#4312)

* llama : support quantum K cache (wip)

* metal : add F32 -> Q8_0 copy kernel

* cuda : add F32 -> Q8_0 copy kernel

ggml-ci

* cuda : use mmv kernel for quantum cache ops

* llama : pass KV cache type through API

* llama : fix build

ggml-ci

* metal : add F32 -> Q4_0 copy kernel

* metal : add F32 -> Q4_1 copy kernel

* cuda : wip

* cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels

* llama-bench : support type_k/type_v

* metal : use mm kernel only for quantum KV cache

* cuda : add comment

* llama : remove memory_f16 and kv_f16 flags

---------

Co-authored-by: slaren <slarengh@gmail.com>

* readme : add API change notice

---------

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov 2023-12-07 13:03:17 +02:00 committed by GitHub
parent 81bc9214a3
commit bcc0eb4591
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 747 additions and 287 deletions

View file

@ -10,6 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
### Hot topics
- **llama.h API change for handling KV cache offloading and data type: https://github.com/ggerganov/llama.cpp/pull/4309**
- Using `llama.cpp` with AWS instances: https://github.com/ggerganov/llama.cpp/discussions/4225
- Looking for contributions to improve and maintain the `server` example: https://github.com/ggerganov/llama.cpp/issues/4216
- Collecting Apple Silicon performance stats: https://github.com/ggerganov/llama.cpp/discussions/4167

View file

@ -278,8 +278,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--samplers") {
if (++i >= argc) {
invalid_param = true;
@ -510,6 +508,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.infill = true;
} else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
params.dump_kv_cache = true;
} else if (arg == "-nkvo" || arg == "--no-kv-offload") {
params.no_kv_offload = true;
} else if (arg == "-ctk" || arg == "--cache-type-k") {
params.cache_type_k = argv[++i];
} else if (arg == "-ctv" || arg == "--cache-type-v") {
params.cache_type_v = argv[++i];
} else if (arg == "--multiline-input") {
params.multiline_input = true;
} else if (arg == "--simple-io") {
@ -858,8 +862,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
@ -900,6 +902,12 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --verbose-prompt print prompt before generation\n");
printf(" -dkvc, --dump-kv-cache\n");
printf(" verbose print of the KV cache\n");
printf(" -nkvo, --no-kv-offload\n");
printf(" disable KV offload\n");
printf(" -ctk TYPE, --cache-type-k TYPE\n");
printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
printf(" -ctv TYPE, --cache-type-v TYPE\n");
printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@ -1015,6 +1023,29 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
return mparams;
}
static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "f16") {
return GGML_TYPE_F16;
}
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}
if (s == "q4_0") {
return GGML_TYPE_Q4_0;
}
if (s == "q4_1") {
return GGML_TYPE_Q4_1;
}
if (s == "q5_0") {
return GGML_TYPE_Q5_0;
}
if (s == "q5_1") {
return GGML_TYPE_Q5_1;
}
throw std::runtime_error("Invalid cache type: " + s);
}
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();
@ -1024,7 +1055,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.mul_mat_q = params.mul_mat_q;
cparams.seed = params.seed;
cparams.f16_kv = params.memory_f16;
cparams.logits_all = params.logits_all;
cparams.embedding = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
@ -1035,6 +1065,10 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.offload_kqv = !params.no_kv_offload;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
return cparams;
}
@ -1447,7 +1481,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
}
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);

View file

@ -100,7 +100,6 @@ struct gpt_params {
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode
@ -125,10 +124,14 @@ struct gpt_params {
bool verbose_prompt = false; // print prompt tokens before generation
bool infill = false; // use infill mode
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
bool no_kv_offload = false; // disable KV offloading
std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V
// multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector
std::string image = ""; // path to an image file
std::string image = ""; // path to an image file
};
bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);

View file

@ -53,6 +53,13 @@ static std::vector<T> split(const std::string & str, char delim) {
return values;
}
template<typename T, typename F>
static std::vector<std::string> transform_to_str(const std::vector<T> & values, F f) {
std::vector<std::string> str_values;
std::transform(values.begin(), values.end(), std::back_inserter(str_values), f);
return str_values;
}
template<typename T>
static T avg(const std::vector<T> & v) {
if (v.empty()) {
@ -126,7 +133,8 @@ struct cmd_params {
std::vector<int> n_prompt;
std::vector<int> n_gen;
std::vector<int> n_batch;
std::vector<bool> f32_kv;
std::vector<ggml_type> type_k;
std::vector<ggml_type> type_v;
std::vector<int> n_threads;
std::vector<int> n_gpu_layers;
std::vector<int> main_gpu;
@ -142,7 +150,8 @@ static const cmd_params cmd_params_defaults = {
/* n_prompt */ {512},
/* n_gen */ {128},
/* n_batch */ {512},
/* f32_kv */ {false},
/* type_k */ {GGML_TYPE_F16},
/* type_v */ {GGML_TYPE_F16},
/* n_threads */ {get_num_physical_cores()},
/* n_gpu_layers */ {99},
/* main_gpu */ {0},
@ -162,7 +171,8 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
printf(" --memory-f32 <0|1> (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str());
printf(" -ctk <t>, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
printf(" -ctv <t>, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
@ -173,9 +183,32 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
printf("\n");
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
}
static ggml_type ggml_type_from_name(const std::string & s) {
if (s == "f16") {
return GGML_TYPE_F16;
}
if (s == "q8_0") {
return GGML_TYPE_Q8_0;
}
if (s == "q4_0") {
return GGML_TYPE_Q4_0;
}
if (s == "q4_1") {
return GGML_TYPE_Q4_1;
}
if (s == "q5_0") {
return GGML_TYPE_Q5_0;
}
if (s == "q5_1") {
return GGML_TYPE_Q5_1;
}
return GGML_TYPE_COUNT;
}
static cmd_params parse_cmd_params(int argc, char ** argv) {
cmd_params params;
std::string arg;
@ -224,13 +257,38 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = split<int>(argv[i], split_delim);
params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
} else if (arg == "--memory-f32") {
} else if (arg == "-ctk" || arg == "--cache-type-k") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = split<int>(argv[i], split_delim);
params.f32_kv.insert(params.f32_kv.end(), p.begin(), p.end());
auto p = split<std::string>(argv[i], split_delim);
std::vector<ggml_type> types;
for (const auto & t : p) {
ggml_type gt = ggml_type_from_name(t);
if (gt == GGML_TYPE_COUNT) {
invalid_param = true;
break;
}
types.push_back(gt);
}
params.type_k.insert(params.type_k.end(), types.begin(), types.end());
} else if (arg == "-ctv" || arg == "--cache-type-v") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = split<std::string>(argv[i], split_delim);
std::vector<ggml_type> types;
for (const auto & t : p) {
ggml_type gt = ggml_type_from_name(t);
if (gt == GGML_TYPE_COUNT) {
invalid_param = true;
break;
}
types.push_back(gt);
}
params.type_v.insert(params.type_v.end(), types.begin(), types.end());
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
@ -321,7 +379,8 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; }
if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; }
if (params.f32_kv.empty()) { params.f32_kv = cmd_params_defaults.f32_kv; }
if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; }
if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.mul_mat_q.empty()) { params.mul_mat_q = cmd_params_defaults.mul_mat_q; }
@ -336,7 +395,8 @@ struct cmd_params_instance {
int n_prompt;
int n_gen;
int n_batch;
bool f32_kv;
ggml_type type_k;
ggml_type type_v;
int n_threads;
int n_gpu_layers;
int main_gpu;
@ -365,7 +425,8 @@ struct cmd_params_instance {
cparams.n_ctx = n_prompt + n_gen;
cparams.n_batch = n_batch;
cparams.f16_kv = !f32_kv;
cparams.type_k = type_k;
cparams.type_v = type_v;
cparams.mul_mat_q = mul_mat_q;
return cparams;
@ -380,7 +441,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
for (const auto & mg : params.main_gpu)
for (const auto & ts : params.tensor_split)
for (const auto & nb : params.n_batch)
for (const auto & fk : params.f32_kv)
for (const auto & tk : params.type_k)
for (const auto & tv : params.type_v)
for (const auto & mmq : params.mul_mat_q)
for (const auto & nt : params.n_threads) {
cmd_params_instance instance = {
@ -388,7 +450,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances_int(const cmd_p
/* .n_prompt = */ n_prompt,
/* .n_gen = */ n_gen,
/* .n_batch = */ nb,
/* .f32_kv = */ fk,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
/* .n_gpu_layers = */ nl,
/* .main_gpu = */ mg,
@ -410,7 +473,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & mg : params.main_gpu)
for (const auto & ts : params.tensor_split)
for (const auto & nb : params.n_batch)
for (const auto & fk : params.f32_kv)
for (const auto & tk : params.type_k)
for (const auto & tv : params.type_v)
for (const auto & mmq : params.mul_mat_q)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
@ -422,7 +486,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .n_prompt = */ n_prompt,
/* .n_gen = */ 0,
/* .n_batch = */ nb,
/* .f32_kv = */ fk,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
/* .n_gpu_layers = */ nl,
/* .main_gpu = */ mg,
@ -441,7 +506,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .n_prompt = */ 0,
/* .n_gen = */ n_gen,
/* .n_batch = */ nb,
/* .f32_kv = */ fk,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
/* .n_gpu_layers = */ nl,
/* .main_gpu = */ mg,
@ -489,7 +555,8 @@ struct test {
uint64_t model_n_params;
int n_batch;
int n_threads;
bool f32_kv;
ggml_type type_k;
ggml_type type_v;
int n_gpu_layers;
int main_gpu;
bool mul_mat_q;
@ -508,7 +575,8 @@ struct test {
model_n_params = llama_model_n_params(lmodel);
n_batch = inst.n_batch;
n_threads = inst.n_threads;
f32_kv = inst.f32_kv;
type_k = inst.type_k;
type_v = inst.type_v;
n_gpu_layers = inst.n_gpu_layers;
main_gpu = inst.main_gpu;
mul_mat_q = inst.mul_mat_q;
@ -571,7 +639,7 @@ struct test {
"cuda", "opencl", "metal", "gpu_blas", "blas",
"cpu_info", "gpu_info",
"model_filename", "model_type", "model_size", "model_n_params",
"n_batch", "n_threads", "f16_kv",
"n_batch", "n_threads", "type_k", "type_v",
"n_gpu_layers", "main_gpu", "mul_mat_q", "tensor_split",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@ -621,7 +689,7 @@ struct test {
std::to_string(cuda), std::to_string(opencl), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
cpu_info, gpu_info,
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv),
std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), tensor_split_str,
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@ -805,8 +873,11 @@ struct markdown_printer : public printer {
if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
fields.push_back("n_batch");
}
if (params.f32_kv.size() > 1 || params.f32_kv != cmd_params_defaults.f32_kv) {
fields.push_back("f16_kv");
if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
fields.push_back("type_k");
}
if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) {
fields.push_back("type_v");
}
if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
fields.push_back("main_gpu");

View file

@ -321,7 +321,6 @@ int main(int argc, char ** argv) {
auto cparams = llama_context_default_params();
cparams.n_ctx = 256;
cparams.seed = 1;
cparams.f16_kv = false;
ctx = llama_new_context_with_model(model, cparams);

View file

@ -2108,10 +2108,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.yarn_beta_slow = std::stof(argv[i]);
}
else if (arg == "--memory-f32" || arg == "--memory_f32")
{
params.memory_f16 = false;
}
else if (arg == "--threads" || arg == "-t")
{
if (++i >= argc)

View file

@ -7,6 +7,7 @@
#include <stdio.h>
#include <atomic>
#include <assert.h>
#include <float.h>
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
@ -4559,6 +4560,116 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
cpy_1(cx + x_offset, cdst + dst_offset);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = xi[j];
amax = fmaxf(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = xi[j]*id;
dsti->qs[j] = roundf(x0);
}
}
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_0 * dsti = (block_q4_0 *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_0/2 + j]*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
dsti->qs[j] = xi0;
dsti->qs[j] |= xi1 << 4;
}
}
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_1 * dsti = (block_q4_1 *) cdsti;
float vmin = FLT_MAX;
float vmax = -FLT_MAX;
for (int j = 0; j < QK4_1; ++j) {
const float v = xi[j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d;
dsti->dm.y = vmin;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (xi[0 + j] - vmin)*id;
const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
dsti->qs[j] = xi0;
dsti->qs[j] |= xi1 << 4;
}
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) {
return;
}
const int i02 = i / (ne00*ne01);
const int i01 = (i - i02*ne01*ne00) / ne00;
const int i00 = (i - i02*ne01*ne00 - i01*ne00);
const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
const int i12 = i / (ne10*ne11);
const int i11 = (i - i12*ne10*ne11) / ne10;
const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
cpy_blck(cx + x_offset, cdst + dst_offset);
}
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
@ -5737,6 +5848,39 @@ static void ggml_cpy_f32_f16_cuda(
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
static void ggml_cpy_f32_q8_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
GGML_ASSERT(ne % QK8_0 == 0);
const int num_blocks = ne / QK8_0;
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_0 == 0);
const int num_blocks = ne / QK4_0;
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_1 == 0);
const int num_blocks = ne / QK4_1;
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
}
static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@ -6093,20 +6237,21 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
const enum ggml_type type = src->type;
const int64_t ts = ggml_type_size(type);
const int64_t bs = ggml_blck_size(type);
int64_t i1_diff = i1_high - i1_low;
const int64_t i1_diff = i1_high - i1_low;
const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
if (nb0 == ts && nb1 == ts*ne0/bs) {
if (nb0 == ts && nb1 == ts*(ne0/bs)) {
return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
}
if (nb0 == ts) {
return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
return cudaMemcpy2DAsync(dst_ptr, ts*(ne0/bs), x, nb1, ts*(ne0/bs), i1_diff, kind, stream);
}
GGML_ASSERT(bs == 1 && "TODO: implement bs != 1");
for (int64_t i1 = 0; i1 < i1_diff; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
void * rd = (void *) (dst_ptr + i1*ts*ne0);
// pretend the row is a matrix with cols=1
cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
cudaError_t r = cudaMemcpy2DAsync(rd, ts, rx, nb0, ts, ne0, kind, stream);
if (r != cudaSuccess) { return r; }
}
return cudaSuccess;
@ -6474,6 +6619,8 @@ inline void ggml_cuda_op_mul_mat_vec_q(
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, const cudaStream_t & stream) {
GGML_ASSERT(ggml_nrows(src1) == 1);
const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;
@ -6533,7 +6680,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
size_t ash;
dfloat * src1_dfloat = nullptr; // dfloat == half
bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
bool src1_convert_f16 =
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
@ -7103,10 +7251,9 @@ static void ggml_cuda_op_mul_mat(
const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
const bool src0_is_contiguous = ggml_is_contiguous(src0);
const bool src1_is_contiguous = ggml_is_contiguous(src1);
const int64_t src1_padded_col_size = ne10 % MATRIX_ROW_PADDING == 0 ?
ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING;
const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
GGML_ASSERT(!(split && ne02 > 1));
@ -7231,7 +7378,7 @@ static void ggml_cuda_op_mul_mat(
const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
// for split tensors the data begins at i0 == i0_offset_low
char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * ne01*ne00*src0_ts/src0_bs;
char * src0_dd_i = src0_dd[id] + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
float * src1_ddf_i = src1_ddf[id] + (i0*ne11 + src1_col_0) * ne10;
char * src1_ddq_i = src1_ddq[id] + src1_ddq_i_offset;
float * dst_dd_i = dst_dd[id] + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@ -7698,10 +7845,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
#ifdef GGML_CUDA_FORCE_DMMV
const bool use_mul_mat_vec_q = false;
#else
const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
#endif // GGML_CUDA_FORCE_DMMV
if (use_mul_mat_vec_q) {
// NOTE: this kernel does not support ggml_nrows(src1) > 1
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
} else {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
@ -7770,14 +7918,17 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
ne10, ne11, nb10, nb11, nb12, main_stream);
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@ -7788,6 +7939,7 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
}
static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
// TODO: why do we pass dst as src1 here?
ggml_cuda_cpy(src0, dst, nullptr);
(void) src1;
}

View file

@ -118,6 +118,11 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(im2col_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
GGML_METAL_DECL_KERNEL(concat);
GGML_METAL_DECL_KERNEL(sqr);
@ -324,6 +329,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(im2col_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
GGML_METAL_ADD_KERNEL(concat);
GGML_METAL_ADD_KERNEL(sqr);
@ -425,6 +435,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(im2col_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
GGML_METAL_DEL_KERNEL(concat);
GGML_METAL_DEL_KERNEL(sqr);
@ -1114,7 +1129,7 @@ void ggml_metal_graph_compute(
!ggml_is_transposed(src1) &&
src1t == GGML_TYPE_F32 &&
ne00 % 32 == 0 && ne00 >= 64 &&
ne11 > ne11_mm_min) {
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
switch (src0->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
@ -1549,14 +1564,23 @@ void ggml_metal_graph_compute(
case GGML_OP_CPY:
case GGML_OP_CONT:
{
const int nth = MIN(1024, ne00);
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
switch (src0t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
//case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
//case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;

View file

@ -3,6 +3,7 @@
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define QK4_0 32
#define QR4_0 2
@ -1460,6 +1461,197 @@ kernel void kernel_cpy_f32_f32(
}
}
kernel void kernel_cpy_f32_q8_0(
device const float * src0,
device void * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = src[j];
amax = MAX(amax, fabs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst_data[i00/QK8_0].d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = src[j]*id;
dst_data[i00/QK8_0].qs[j] = round(x0);
}
}
}
kernel void kernel_cpy_f32_q4_0(
device const float * src0,
device void * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float amax = 0.0f; // absolute max
float max = 0.0f;
for (int j = 0; j < QK4_0; j++) {
const float v = src[j];
if (amax < fabs(v)) {
amax = fabs(v);
max = v;
}
}
const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f;
dst_data[i00/QK4_0].d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = src[0 + j]*id;
const float x1 = src[QK4_0/2 + j]*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
dst_data[i00/QK4_0].qs[j] = xi0;
dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
}
}
}
kernel void kernel_cpy_f32_q4_1(
device const float * src0,
device void * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
float min = FLT_MAX;
float max = -FLT_MAX;
for (int j = 0; j < QK4_1; j++) {
const float v = src[j];
if (min > v) min = v;
if (max < v) max = v;
}
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
dst_data[i00/QK4_1].d = d;
dst_data[i00/QK4_1].m = min;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (src[0 + j] - min)*id;
const float x1 = (src[QK4_1/2 + j] - min)*id;
const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
dst_data[i00/QK4_1].qs[j] = xi0;
dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
}
}
}
kernel void kernel_concat(
device const char * src0,
device const char * src1,

440
llama.cpp
View file

@ -1231,6 +1231,7 @@ struct llama_cparams {
float yarn_beta_slow;
bool mul_mat_q;
bool offload_kqv;
};
struct llama_layer {
@ -1299,8 +1300,8 @@ struct llama_kv_cache {
std::vector<llama_kv_cell> cells;
struct ggml_tensor * k = NULL;
struct ggml_tensor * v = NULL;
std::vector<struct ggml_tensor *> k_l; // per layer
std::vector<struct ggml_tensor *> v_l;
struct ggml_context * ctx = NULL;
@ -1313,8 +1314,10 @@ struct llama_kv_cache {
#ifdef GGML_USE_CUBLAS
if (ggml_cublas_loaded()) {
ggml_cuda_free_data(k);
ggml_cuda_free_data(v);
for (size_t i = 0; i < k_l.size(); ++i) {
ggml_cuda_free_data(k_l[i]);
ggml_cuda_free_data(v_l[i]);
}
}
#endif
}
@ -1504,9 +1507,11 @@ struct llama_context {
static bool llama_kv_cache_init(
const struct llama_hparams & hparams,
struct llama_kv_cache & cache,
ggml_type wtype,
ggml_type ktype,
ggml_type vtype,
uint32_t n_ctx,
int n_gpu_layers) {
int n_gpu_layers,
bool offload) {
const uint32_t n_embd = hparams.n_embd_gqa();
const uint32_t n_layer = hparams.n_layer;
@ -1522,7 +1527,7 @@ static bool llama_kv_cache_init(
cache.cells.clear();
cache.cells.resize(n_ctx);
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
cache.buf.resize(n_elements*(ggml_type_sizef(ktype) + ggml_type_sizef(vtype)) + 2u*n_layer*ggml_tensor_overhead());
memset(cache.buf.data, 0, cache.buf.size);
struct ggml_init_params params;
@ -1532,37 +1537,44 @@ static bool llama_kv_cache_init(
cache.ctx = ggml_init(params);
size_t vram_kv_cache = 0;
if (!cache.ctx) {
LLAMA_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__);
return false;
}
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
ggml_set_name(cache.k, "cache_k");
ggml_set_name(cache.v, "cache_v");
cache.k_l.reserve(n_layer);
cache.v_l.reserve(n_layer);
(void) n_gpu_layers;
const int i_gpu_start = (int) n_layer - n_gpu_layers; GGML_UNUSED(i_gpu_start);
GGML_UNUSED(offload);
for (int i = 0; i < (int) n_layer; i++) {
ggml_tensor * k = ggml_new_tensor_1d(cache.ctx, ktype, n_embd*n_ctx);
ggml_tensor * v = ggml_new_tensor_1d(cache.ctx, vtype, n_embd*n_ctx);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
cache.k_l.push_back(k);
cache.v_l.push_back(v);
#ifdef GGML_USE_CUBLAS
if (ggml_cublas_loaded()) {
size_t vram_kv_cache = 0;
if (n_gpu_layers > (int)n_layer + 1) {
ggml_cuda_assign_buffers_no_scratch(cache.v);
LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
vram_kv_cache += ggml_nbytes(cache.v);
}
if (n_gpu_layers > (int)n_layer + 2) {
ggml_cuda_assign_buffers_no_scratch(cache.k);
LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
vram_kv_cache += ggml_nbytes(cache.k);
}
if (vram_kv_cache > 0) {
LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MiB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
if (i >= i_gpu_start) {
if (offload) {
ggml_cuda_assign_buffers_no_scratch(k);
vram_kv_cache += ggml_nbytes(k);
ggml_cuda_assign_buffers_no_scratch(v);
vram_kv_cache += ggml_nbytes(v);
}
}
#endif // GGML_USE_CUBLAS
}
#endif
if (vram_kv_cache > 0) {
LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
}
GGML_UNUSED(n_gpu_layers);
return true;
}
@ -2968,14 +2980,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3045,14 +3050,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3115,14 +3113,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3192,14 +3183,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3269,21 +3253,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > int(n_layer + 1)) {
LLAMA_LOG_ERROR("%s: CUDA backend missing Persimmon CUDA ops, can offload at most %ld layers. See: https://github.com/ggerganov/llama.cpp/issues/4038\n",
__func__, n_layer + 1);
throw std::runtime_error("Persimmon CUDA offload failed");
}
#endif
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3342,14 +3312,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3420,14 +3383,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3487,14 +3443,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3559,14 +3508,7 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = llama_backend_offload;
#else
backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
#endif // _WIN32
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
@ -3642,8 +3584,8 @@ static void llm_load_tensors(
}
#ifdef GGML_USE_CUBLAS
const int max_backend_supported_layers = hparams.n_layer + 3;
const int max_offloadable_layers = hparams.n_layer + 3;
const int max_backend_supported_layers = hparams.n_layer + 1;
const int max_offloadable_layers = hparams.n_layer + 1;
#elif GGML_USE_CLBLAST
const int max_backend_supported_layers = hparams.n_layer + 1;
const int max_offloadable_layers = hparams.n_layer + 1;
@ -3811,11 +3753,11 @@ static void llm_build_k_shift(
struct ggml_tensor * tmp =
// we rotate only the first n_rot dimensions
ggml_rope_custom_inplace(ctx,
ggml_view_3d(ctx, kv.k,
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head, n_head_kv, n_ctx,
ggml_element_size(kv.k)*n_embd_head,
ggml_element_size(kv.k)*n_embd_gqa,
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
0),
K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il);
@ -3842,13 +3784,13 @@ static void llm_build_kv_store(
//struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
cb(v_cur_t, "v_cur_t", il);
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k, n_tokens*n_embd_gqa,
(ggml_element_size(kv.k)*n_embd_gqa)*(il*n_ctx + kv_head));
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_gqa,
(ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa)*kv_head);
cb(k_cache_view, "k_cache_view", il);
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v, n_tokens, n_embd_gqa,
( n_ctx)*ggml_element_size(kv.v),
(il*n_ctx)*ggml_element_size(kv.v)*n_embd_gqa + kv_head*ggml_element_size(kv.v));
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_gqa,
( n_ctx)*ggml_element_size(kv.v_l[il]),
(kv_head)*ggml_element_size(kv.v_l[il]));
cb(v_cache_view, "v_cache_view", il);
// important: storing RoPE-ed version of K in the KV cache!
@ -4000,11 +3942,11 @@ static struct ggml_tensor * llm_build_kqv(
cb(q, "q", il);
struct ggml_tensor * k =
ggml_view_3d(ctx, kv.k,
ggml_view_3d(ctx, kv.k_l[il],
n_embd_head, n_kv, n_head_kv,
ggml_element_size(kv.k)*n_embd_gqa,
ggml_element_size(kv.k)*n_embd_head,
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il);
ggml_type_sizef(kv.k_l[il]->type)*n_embd_gqa,
ggml_type_sizef(kv.k_l[il]->type)*n_embd_head,
0);
cb(k, "k", il);
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
@ -4035,11 +3977,11 @@ static struct ggml_tensor * llm_build_kqv(
// split cached v into n_head heads
struct ggml_tensor * v =
ggml_view_3d(ctx, kv.v,
ggml_view_3d(ctx, kv.v_l[il],
n_kv, n_embd_head, n_head_kv,
ggml_element_size(kv.v)*n_ctx,
ggml_element_size(kv.v)*n_ctx*n_embd_head,
ggml_element_size(kv.v)*n_ctx*n_embd_gqa*il);
ggml_element_size(kv.v_l[il])*n_ctx,
ggml_element_size(kv.v_l[il])*n_ctx*n_embd_head,
0);
cb(v, "v", il);
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
@ -4631,6 +4573,7 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
cb(inpL, "imp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
@ -4638,6 +4581,7 @@ struct llm_build_context {
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
@ -5237,15 +5181,15 @@ struct llm_build_context {
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * inp_pos= ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
struct ggml_tensor * KQ_scale= ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
cb(KQ_scale, "KQ_scale", -1);
// KQ_mask (mask for 1 head, it wil be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask= ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
@ -5351,8 +5295,8 @@ struct llm_build_context {
enum llm_offload_func_e {
OFFLOAD_FUNC_NOP,
OFFLOAD_FUNC,
OFFLOAD_FUNC_KQ,
OFFLOAD_FUNC_V,
OFFLOAD_FUNC_FRC, // force offload
OFFLOAD_FUNC_KQV,
OFFLOAD_FUNC_NR,
OFFLOAD_FUNC_EMB,
OFFLOAD_FUNC_OUT,
@ -5438,11 +5382,12 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
//{ "inp_embd", OFFLOAD_FUNC_NR }, // TODO: missing K-quants get_rows kernel
{ "pos_embd", OFFLOAD_FUNC_NR },
{ "inp_pos", OFFLOAD_FUNC_KQ }, // this is often used for KQ ops (e.g. rope)
{ "KQ_scale", OFFLOAD_FUNC_KQ },
{ "KQ_mask", OFFLOAD_FUNC_KQ },
{ "K_shift", OFFLOAD_FUNC_KQ },
{ "K_shifted", OFFLOAD_FUNC_KQ },
{ "inp_pos", OFFLOAD_FUNC_FRC }, // this is often used for KQ ops (e.g. rope)
{ "KQ_scale", OFFLOAD_FUNC_FRC },
{ "KQ_mask", OFFLOAD_FUNC_FRC },
{ "K_shift", OFFLOAD_FUNC_FRC },
{ "K_shifted", OFFLOAD_FUNC },
{ "inp_norm", OFFLOAD_FUNC_NR },
{ "inp_norm_w", OFFLOAD_FUNC_NR },
@ -5455,38 +5400,38 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
{ "attn_norm", OFFLOAD_FUNC },
{ "attn_norm_2", OFFLOAD_FUNC },
{ "wqkv", OFFLOAD_FUNC_KQ },
{ "bqkv", OFFLOAD_FUNC_KQ },
{ "wqkv_clamped", OFFLOAD_FUNC_KQ },
{ "wqkv", OFFLOAD_FUNC_KQV },
{ "bqkv", OFFLOAD_FUNC_KQV },
{ "wqkv_clamped", OFFLOAD_FUNC_KQV },
{ "tmpk", OFFLOAD_FUNC_KQ },
{ "tmpq", OFFLOAD_FUNC_KQ },
{ "tmpv", OFFLOAD_FUNC_V },
{ "Kcur", OFFLOAD_FUNC_KQ },
{ "Qcur", OFFLOAD_FUNC_KQ },
{ "Vcur", OFFLOAD_FUNC_V },
{ "tmpk", OFFLOAD_FUNC_KQV },
{ "tmpq", OFFLOAD_FUNC_KQV },
{ "tmpv", OFFLOAD_FUNC_KQV },
{ "Kcur", OFFLOAD_FUNC_KQV },
{ "Qcur", OFFLOAD_FUNC_KQV },
{ "Vcur", OFFLOAD_FUNC_KQV },
{ "krot", OFFLOAD_FUNC_KQ },
{ "qrot", OFFLOAD_FUNC_KQ },
{ "kpass", OFFLOAD_FUNC_KQ },
{ "qpass", OFFLOAD_FUNC_KQ },
{ "krotated", OFFLOAD_FUNC_KQ },
{ "qrotated", OFFLOAD_FUNC_KQ },
{ "krot", OFFLOAD_FUNC_KQV },
{ "qrot", OFFLOAD_FUNC_KQV },
{ "kpass", OFFLOAD_FUNC_KQV },
{ "qpass", OFFLOAD_FUNC_KQV },
{ "krotated", OFFLOAD_FUNC_KQV },
{ "qrotated", OFFLOAD_FUNC_KQV },
{ "q", OFFLOAD_FUNC_KQ },
{ "k", OFFLOAD_FUNC_KQ },
{ "kq", OFFLOAD_FUNC_KQ },
{ "kq_scaled", OFFLOAD_FUNC_KQ },
{ "kq_scaled_alibi", OFFLOAD_FUNC_KQ },
{ "kq_masked", OFFLOAD_FUNC_KQ },
{ "kq_soft_max", OFFLOAD_FUNC_V },
{ "kq_soft_max_ext", OFFLOAD_FUNC_V },
{ "v", OFFLOAD_FUNC_V },
{ "kqv", OFFLOAD_FUNC_V },
{ "kqv_merged", OFFLOAD_FUNC_V },
{ "kqv_merged_cont", OFFLOAD_FUNC_V },
{ "kqv_wo", OFFLOAD_FUNC_V },
{ "kqv_out", OFFLOAD_FUNC_V },
{ "q", OFFLOAD_FUNC_KQV },
{ "k", OFFLOAD_FUNC_KQV },
{ "kq", OFFLOAD_FUNC_KQV },
{ "kq_scaled", OFFLOAD_FUNC_KQV },
{ "kq_scaled_alibi", OFFLOAD_FUNC_KQV },
{ "kq_masked", OFFLOAD_FUNC_KQV },
{ "kq_soft_max", OFFLOAD_FUNC_KQV },
{ "kq_soft_max_ext", OFFLOAD_FUNC_KQV },
{ "v", OFFLOAD_FUNC_KQV },
{ "kqv", OFFLOAD_FUNC_KQV },
{ "kqv_merged", OFFLOAD_FUNC_KQV },
{ "kqv_merged_cont", OFFLOAD_FUNC_KQV },
{ "kqv_wo", OFFLOAD_FUNC_KQV },
{ "kqv_out", OFFLOAD_FUNC_KQV },
{ "ffn_inp", OFFLOAD_FUNC },
{ "ffn_norm", OFFLOAD_FUNC },
@ -5678,15 +5623,15 @@ static struct ggml_cgraph * llama_build_graph(
{ OFFLOAD_FUNC_NOP, "CPU" },
{ OFFLOAD_FUNC_OUT, "CPU" },
#ifdef GGML_USE_CUBLAS
{ OFFLOAD_FUNC, "GPU (CUDA)" },
{ OFFLOAD_FUNC_KQ, "GPU (CUDA) KQ" },
{ OFFLOAD_FUNC_V, "GPU (CUDA) V" },
{ OFFLOAD_FUNC_NR, "GPU (CUDA) NR" },
{ OFFLOAD_FUNC, "GPU (CUDA)" },
{ OFFLOAD_FUNC_FRC, "GPU (CUDA) FRC" },
{ OFFLOAD_FUNC_KQV, "GPU (CUDA) KQV" },
{ OFFLOAD_FUNC_NR, "GPU (CUDA) NR" },
{ OFFLOAD_FUNC_EMB, "GPU (CUDA) EMB" },
#else
{ OFFLOAD_FUNC, "CPU" },
{ OFFLOAD_FUNC_KQ, "CPU" },
{ OFFLOAD_FUNC_V, "CPU" },
{ OFFLOAD_FUNC_FRC, "CPU" },
{ OFFLOAD_FUNC_KQV, "CPU" },
{ OFFLOAD_FUNC_NR, "CPU" },
{ OFFLOAD_FUNC_EMB, "CPU" },
#endif // GGML_USE_CUBLAS
@ -5719,21 +5664,26 @@ static struct ggml_cgraph * llama_build_graph(
}
}
break;
case OFFLOAD_FUNC_FRC:
if (!lctx.cparams.offload_kqv) {
func_e = OFFLOAD_FUNC_NOP;
} break;
case OFFLOAD_FUNC_KQV:
if (!lctx.cparams.offload_kqv) {
func_e = OFFLOAD_FUNC_NOP;
} else {
if (n_gpu_layers < n_layer) {
if (il < i_gpu_start) {
func_e = OFFLOAD_FUNC_NOP;
}
}
}
break;
case OFFLOAD_FUNC_NR:
if (n_gpu_layers <= n_layer + 0) {
func_e = OFFLOAD_FUNC_NOP;
}
break;
case OFFLOAD_FUNC_V:
if (n_gpu_layers <= n_layer + 1) {
func_e = OFFLOAD_FUNC_NOP;
}
break;
case OFFLOAD_FUNC_KQ:
if (n_gpu_layers <= n_layer + 2) {
func_e = OFFLOAD_FUNC_NOP;
}
break;
case OFFLOAD_FUNC_EMB:
if (!offload_emb || n_gpu_layers < n_layer) {
func_e = OFFLOAD_FUNC_NOP;
@ -5755,8 +5705,8 @@ static struct ggml_cgraph * llama_build_graph(
case OFFLOAD_FUNC_NOP:
case OFFLOAD_FUNC_OUT: func = ggml_offload_nop; break;
case OFFLOAD_FUNC:
case OFFLOAD_FUNC_KQ:
case OFFLOAD_FUNC_V:
case OFFLOAD_FUNC_KQV:
case OFFLOAD_FUNC_FRC:
case OFFLOAD_FUNC_NR:
case OFFLOAD_FUNC_EMB: func = ggml_offload_gpu; break;
default: GGML_ASSERT(false);
@ -5942,6 +5892,7 @@ static int llama_decode_internal(
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@ -5992,7 +5943,7 @@ static int llama_decode_internal(
n_threads = std::min(4, n_threads);
}
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3;
const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1;
if (ggml_cpu_has_cublas() && fully_offloaded) {
n_threads = 1;
}
@ -8821,10 +8772,12 @@ struct llama_context_params llama_context_default_params() {
/*.yarn_beta_fast =*/ 32.0f,
/*.yarn_beta_slow =*/ 1.0f,
/*.yarn_orig_ctx =*/ 0,
/*.type_k =*/ GGML_TYPE_F16,
/*.type_v =*/ GGML_TYPE_F16,
/*.mul_mat_q =*/ true,
/*.f16_kv =*/ true,
/*.logits_all =*/ false,
/*.embedding =*/ false,
/*.offload_kqv =*/ true,
};
return result;
@ -8941,6 +8894,7 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.mul_mat_q = params.mul_mat_q;
cparams.offload_kqv = params.offload_kqv;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@ -8974,19 +8928,36 @@ struct llama_context * llama_new_context_with_model(
ctx->rng = std::mt19937(params.seed);
ctx->logits_all = params.logits_all;
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
const ggml_type type_k = params.type_k;
const ggml_type type_v = params.type_v;
GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head() % ggml_blck_size(type_v) == 0);
// reserve memory for context buffers
if (!hparams.vocab_only) {
if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, cparams.n_ctx, model->n_gpu_layers)) {
if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, type_k, type_v, cparams.n_ctx, model->n_gpu_layers, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
}
{
const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v);
LLAMA_LOG_INFO("%s: kv self size = %7.2f MiB\n", __func__, memory_size / 1024.0 / 1024.0);
size_t memory_size_k = 0;
size_t memory_size_v = 0;
for (auto & k : ctx->kv_self.k_l) {
memory_size_k += ggml_nbytes(k);
}
for (auto & v : ctx->kv_self.v_l) {
memory_size_v += ggml_nbytes(v);
}
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
// resized during inference
@ -9057,8 +9028,12 @@ struct llama_context * llama_new_context_with_model(
}
size_t kv_vram_size = 0;
add_tensor(ctx->kv_self.k, kv_vram_size);
add_tensor(ctx->kv_self.v, kv_vram_size);
for (auto & k : ctx->kv_self.k_l) {
add_tensor(k, kv_vram_size);
}
for (auto & v : ctx->kv_self.v_l) {
add_tensor(v, kv_vram_size);
}
size_t ctx_vram_size = alloc_size + kv_vram_size;
size_t total_vram_size = model_vram_size + ctx_vram_size;
@ -9528,37 +9503,45 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
data_ctx->write(&kv_used, sizeof(kv_used));
if (kv_buf_size) {
const size_t elt_size = ggml_element_size(kv_self.k);
const size_t elt_size = ggml_element_size(kv_self.k_l[0]);
ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
std::vector<uint8_t> kout3d_data(ggml_nbytes(kout3d), 0);
kout3d->data = kout3d_data.data();
std::vector<std::vector<uint8_t>> kout2d_data(n_layer);
std::vector<std::vector<uint8_t>> vout2d_data(n_layer);
ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
std::vector<uint8_t> vout3d_data(ggml_nbytes(vout3d), 0);
vout3d->data = vout3d_data.data();
for (int il = 0; il < (int) n_layer; ++il) {
ggml_tensor * kout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
kout2d_data[il].resize(ggml_nbytes(kout2d));
kout2d->data = kout2d_data[il].data();
ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
n_embd, kv_head, n_layer,
elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
ggml_tensor * vout2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
vout2d_data[il].resize(ggml_nbytes(vout2d));
vout2d->data = vout2d_data[il].data();
ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
kv_head, n_embd, n_layer,
elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
n_embd, kv_head,
elt_size*n_embd, 0);
ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
kv_head, n_embd,
elt_size*n_ctx, 0);
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k2d, kout2d));
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v2d, vout2d));
}
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, k3d, kout3d));
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, v3d, vout3d));
ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
ggml_free(cpy_ctx);
// our data is now in the kout3d_data and vout3d_data buffers
// our data is now in the kout2d_data and vout2d_data buffers
// write them to file
data_ctx->write(kout3d_data.data(), kout3d_data.size());
data_ctx->write(vout3d_data.data(), vout3d_data.size());
for (uint32_t il = 0; il < n_layer; ++il) {
data_ctx->write(kout2d_data[il].data(), kout2d_data[il].size());
data_ctx->write(vout2d_data[il].data(), vout2d_data[il].size());
}
}
for (uint32_t i = 0; i < kv_size; ++i) {
@ -9658,29 +9641,32 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
if (kv_buf_size) {
GGML_ASSERT(kv_self.buf.size == kv_buf_size);
const size_t elt_size = ggml_element_size(kv_self.k);
const size_t elt_size = ggml_element_size(kv_self.k_l[0]);
ggml_context * cpy_ctx = ggml_init({ 6*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
ggml_context * cpy_ctx = ggml_init({ 6*n_layer*ggml_tensor_overhead() + ggml_graph_overhead(), NULL, /* no_alloc */ true });
ggml_cgraph * gf = ggml_new_graph(cpy_ctx);
ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_head, n_layer);
kin3d->data = (void *) inp;
inp += ggml_nbytes(kin3d);
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * kin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.k_l[il]->type, n_embd, kv_head);
kin2d->data = (void *) inp;
inp += ggml_nbytes(kin2d);
ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_head, n_embd, n_layer);
vin3d->data = (void *) inp;
inp += ggml_nbytes(vin3d);
ggml_tensor * vin2d = ggml_new_tensor_2d(cpy_ctx, kv_self.v_l[il]->type, kv_head, n_embd);
vin2d->data = (void *) inp;
inp += ggml_nbytes(vin2d);
ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
n_embd, kv_head, n_layer,
elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
ggml_tensor * k2d = ggml_view_2d(cpy_ctx, kv_self.k_l[il],
n_embd, kv_head,
elt_size*n_embd, 0);
ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
kv_head, n_embd, n_layer,
elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
ggml_tensor * v2d = ggml_view_2d(cpy_ctx, kv_self.v_l[il],
kv_head, n_embd,
elt_size*n_ctx, 0);
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin2d, k2d));
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin2d, v2d));
}
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, kin3d, k3d));
ggml_build_forward_expand(gf, ggml_cpy(cpy_ctx, vin3d, v3d));
ggml_graph_compute_helper(ctx->work_buffer, gf, /*n_threads*/ 1);
ggml_free(cpy_ctx);

13
llama.h
View file

@ -42,7 +42,7 @@
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 2
#define LLAMA_SESSION_VERSION 3
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
@ -211,11 +211,14 @@ extern "C" {
float yarn_beta_slow; // YaRN high correction dim
uint32_t yarn_orig_ctx; // YaRN original context size
enum ggml_type type_k; // data type for K cache
enum ggml_type type_v; // data type for V cache
// Keep the booleans together to avoid misalignment during copy-by-value.
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
bool f16_kv; // use fp16 for KV cache, fp32 otherwise
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool embedding; // embedding mode only
bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool embedding; // embedding mode only
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
};
// model quantization parameters