CUDA full GPU acceleration, KV cache in VRAM (#1827)

* Fixed CUDA RoPE

* ggml_cuda_mul_mat_vec_p021

* ggml_cuda_scale

* ggml_cuda_diag_mask_inf

* ggml_is_permuted

* ggml_cuda_cpy

* flatten rows for ggml_cuda_op

* Added a --low-vram option

* Fixed Windows performance

* Fixed LLAMA_CUDA_DMMV_Y > 1 for WizardLM
This commit is contained in:
Johannes Gäßler 2023-06-14 19:47:19 +02:00 committed by GitHub
parent 9254920265
commit 254a7a7a5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 853 additions and 149 deletions

View file

@ -331,6 +331,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} }
#else #else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n"); fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--low-vram" || arg == "-lv") {
#ifdef GGML_USE_CUBLAS
params.low_vram = true;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
} else if (arg == "--no-mmap") { } else if (arg == "--no-mmap") {
params.use_mmap = false; params.use_mmap = false;
@ -479,6 +485,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -ts SPLIT --tensor-split SPLIT\n"); fprintf(stderr, " -ts SPLIT --tensor-split SPLIT\n");
fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" ); fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
fprintf(stderr, " -lv, --low-vram don't allocate VRAM scratch buffer\n" );
#endif #endif
fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --mtest compute maximum memory usage\n");
fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n"); fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n");
@ -528,6 +535,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
lparams.n_gpu_layers = params.n_gpu_layers; lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu; lparams.main_gpu = params.main_gpu;
memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float)); memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float));
lparams.low_vram = params.low_vram;
lparams.seed = params.seed; lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16; lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap; lparams.use_mmap = params.use_mmap;

View file

@ -21,15 +21,16 @@
int32_t get_num_physical_cores(); int32_t get_num_physical_cores();
struct gpt_params { struct gpt_params {
int32_t seed = -1; // RNG seed int32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores(); int32_t n_threads = get_num_physical_cores();
int32_t n_predict = -1; // new tokens to predict int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_gpu_layers = 0; // number of layers to store in VRAM int32_t n_gpu_layers = 0; // number of layers to store in VRAM
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
bool low_vram = 0; // if true, reduce VRAM usage at the cost of performance
// sampling parameters // sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

View file

@ -288,5 +288,6 @@ These options provide extra functionality and customization when running the LLa
- `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance. - `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS. - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS. - `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains. - `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation. - `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.

View file

@ -289,6 +289,7 @@ Test();
- `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance. - `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS. - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS. - `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
- `--embedding`: Enable the embedding mode. **Completion function doesn't work in this mode**. - `--embedding`: Enable the embedding mode. **Completion function doesn't work in this mode**.
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`; - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`;
- `--port`: Set the port to listen. Default: `8080`. - `--port`: Set the port to listen. Default: `8080`.

View file

@ -405,6 +405,7 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" ); fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
fprintf(stderr, " -lv, --low-vram don't allocate VRAM scratch buffer\n" );
#endif #endif
fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
@ -537,6 +538,14 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
} }
#else #else
fprintf(stderr, "WARNING: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n"); fprintf(stderr, "WARNING: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
#endif // GGML_USE_CUBLAS
}
else if (arg == "--low-vram" || arg == "-lv")
{
#ifdef GGML_USE_CUBLAS
params.low_vram = true;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
} }
else if (arg == "--main-gpu" || arg == "-mg") else if (arg == "--main-gpu" || arg == "-mg")

File diff suppressed because it is too large Load diff

View file

@ -28,8 +28,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
void ggml_cuda_free_data(struct ggml_tensor * tensor); void ggml_cuda_free_data(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor); void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
void ggml_cuda_set_main_device(int main_device); void ggml_cuda_set_main_device(int main_device);
void ggml_cuda_set_scratch_size(size_t scratch_size); void ggml_cuda_set_scratch_size(size_t scratch_size);
void ggml_cuda_free_scratch(void);
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor); bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
#ifdef __cplusplus #ifdef __cplusplus

6
ggml.c
View file

@ -3939,6 +3939,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
} }
// Returns true when the tensor's nb[] entries are not in non-decreasing order,
// i.e. when some nb[i] is larger than nb[i + 1]; returns false otherwise.
bool ggml_is_permuted(const struct ggml_tensor * tensor) {
    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

    // compare each adjacent pair (0,1), (1,2), (2,3) and stop at the first decrease
    for (int dim = 0; dim + 1 < GGML_MAX_DIMS; ++dim) {
        if (tensor->nb[dim] > tensor->nb[dim + 1]) {
            return true;
        }
    }

    return false;
}
static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

1
ggml.h
View file

@ -485,6 +485,7 @@ extern "C" {
GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor); GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
// use this to compute the memory overhead of a tensor // use this to compute the memory overhead of a tensor
GGML_API size_t ggml_tensor_overhead(void); GGML_API size_t ggml_tensor_overhead(void);

159
llama.cpp
View file

@ -165,6 +165,11 @@ struct llama_kv_cache {
if (ctx) { if (ctx) {
ggml_free(ctx); ggml_free(ctx);
} }
#ifdef GGML_USE_CUBLAS
ggml_cuda_free_data(k);
ggml_cuda_free_data(v);
#endif // GGML_USE_CUBLAS
} }
}; };
@ -210,6 +215,7 @@ struct llama_model {
for (size_t i = 0; i < tensors_by_name.size(); ++i) { for (size_t i = 0; i < tensors_by_name.size(); ++i) {
ggml_cuda_free_data(tensors_by_name[i].second); ggml_cuda_free_data(tensors_by_name[i].second);
} }
ggml_cuda_free_scratch();
#elif defined(GGML_USE_CLBLAST) #elif defined(GGML_USE_CLBLAST)
for (size_t i = 0; i < tensors_by_name.size(); ++i) { for (size_t i = 0; i < tensors_by_name.size(); ++i) {
ggml_cl_free_data(tensors_by_name[i].second); ggml_cl_free_data(tensors_by_name[i].second);
@ -867,7 +873,8 @@ static bool kv_cache_init(
const struct llama_hparams & hparams, const struct llama_hparams & hparams,
struct llama_kv_cache & cache, struct llama_kv_cache & cache,
ggml_type wtype, ggml_type wtype,
int n_ctx) { int n_ctx,
int n_gpu_layers) {
const int n_embd = hparams.n_embd; const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer; const int n_layer = hparams.n_layer;
@ -893,6 +900,15 @@ static bool kv_cache_init(
ggml_set_name(cache.k, "cache_k"); ggml_set_name(cache.k, "cache_k");
ggml_set_name(cache.v, "cache_v"); ggml_set_name(cache.v, "cache_v");
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer + 1) {
ggml_cuda_assign_buffers_no_scratch(cache.v);
}
if (n_gpu_layers > n_layer + 2) {
ggml_cuda_assign_buffers_no_scratch(cache.k);
}
#endif // GGML_USE_CUBLAS
return true; return true;
} }
@ -903,6 +919,7 @@ struct llama_context_params llama_context_default_params() {
/*.gpu_layers =*/ 0, /*.gpu_layers =*/ 0,
/*.main_gpu =*/ 0, /*.main_gpu =*/ 0,
/*.tensor_split =*/ {0}, /*.tensor_split =*/ {0},
/*.low_vram =*/ false,
/*.seed =*/ -1, /*.seed =*/ -1,
/*.f16_kv =*/ true, /*.f16_kv =*/ true,
/*.logits_all =*/ false, /*.logits_all =*/ false,
@ -1011,6 +1028,7 @@ static void llama_model_load_internal(
int n_gpu_layers, int n_gpu_layers,
int main_gpu, int main_gpu,
const float * tensor_split, const float * tensor_split,
bool low_vram,
ggml_type memory_type, ggml_type memory_type,
bool use_mmap, bool use_mmap,
bool use_mlock, bool use_mlock,
@ -1137,18 +1155,34 @@ static void llama_model_load_internal(
ml->ggml_ctx = ctx; ml->ggml_ctx = ctx;
model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU); model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.norm = ml->get_tensor("norm.weight", {n_embd}, GGML_BACKEND_CPU);
// "output" tensor // "output" tensor
{ {
ggml_backend backend_norm;
ggml_backend backend_output; ggml_backend backend_output;
if (n_gpu_layers > int(n_layer)) { // NOLINT if (n_gpu_layers > int(n_layer)) { // NOLINT
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#else
backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#endif // _WIN32
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
} else { } else {
backend_norm = GGML_BACKEND_CPU;
backend_output = GGML_BACKEND_CPU; backend_output = GGML_BACKEND_CPU;
} }
model.norm = ml->get_tensor("norm.weight", {n_embd}, backend_norm);
model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output);
if (backend_norm == GGML_BACKEND_GPU) {
vram_weights += ggml_nbytes(model.norm);
}
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
vram_weights += ggml_nbytes(model.output);
}
} }
const int i_gpu_start = n_layer - n_gpu_layers; const int i_gpu_start = n_layer - n_gpu_layers;
@ -1208,22 +1242,47 @@ static void llama_model_load_internal(
(void) vram_scratch; (void) vram_scratch;
(void) n_batch; (void) n_batch;
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
vram_scratch = n_batch * MB; if (low_vram) {
ggml_cuda_set_scratch_size(vram_scratch); fprintf(stderr, "%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
if (n_gpu_layers > 0) { ggml_cuda_set_scratch_size(0); // disable scratch
fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n", } else {
__func__, vram_scratch / MB); vram_scratch = n_batch * MB;
ggml_cuda_set_scratch_size(vram_scratch);
if (n_gpu_layers > 0) {
fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
__func__, vram_scratch / MB);
}
} }
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
fprintf(stderr, "%s: offloading %d layers to GPU\n", __func__, n_gpu); fprintf(stderr, "%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
if (n_gpu_layers > (int) hparams.n_layer) { if (n_gpu_layers > (int) hparams.n_layer) {
fprintf(stderr, "%s: offloading output layer to GPU\n", __func__); fprintf(stderr, "%s: offloading non-repeating layers to GPU\n", __func__);
} }
size_t vram_kv_cache = 0;
if (n_gpu_layers > (int) hparams.n_layer + 1) {
if (low_vram) {
fprintf(stderr, "%s: cannot offload v cache to GPU due to low VRAM option\n", __func__);
} else {
fprintf(stderr, "%s: offloading v cache to GPU\n", __func__);
vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2;
}
}
if (n_gpu_layers > (int) hparams.n_layer + 2) {
if (low_vram) {
fprintf(stderr, "%s: cannot offload k cache to GPU due to low VRAM option\n", __func__);
} else {
fprintf(stderr, "%s: offloading k cache to GPU\n", __func__);
vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2;
}
}
const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3;
fprintf(stderr, "%s: offloaded %d/%d layers to GPU\n",
__func__, std::min(n_gpu_layers, max_offloadable_layers), hparams.n_layer + 3);
fprintf(stderr, "%s: total VRAM used: %zu MB\n", fprintf(stderr, "%s: total VRAM used: %zu MB\n",
__func__, (vram_weights + vram_scratch + MB - 1) / MB); // round up __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up
#else #else
(void) n_gpu_layers; (void) n_gpu_layers;
#endif #endif
@ -1262,6 +1321,7 @@ static bool llama_model_load(
int n_gpu_layers, int n_gpu_layers,
int main_gpu, int main_gpu,
float * tensor_split, float * tensor_split,
bool low_vram,
ggml_type memory_type, ggml_type memory_type,
bool use_mmap, bool use_mmap,
bool use_mlock, bool use_mlock,
@ -1269,7 +1329,7 @@ static bool llama_model_load(
llama_progress_callback progress_callback, llama_progress_callback progress_callback,
void *progress_callback_user_data) { void *progress_callback_user_data) {
try { try {
llama_model_load_internal(fname, lctx, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, memory_type, llama_model_load_internal(fname, lctx, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, low_vram, memory_type,
use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
return true; return true;
} catch (const std::exception & err) { } catch (const std::exception & err) {
@ -1345,12 +1405,33 @@ static bool llama_eval_internal(
const int i_gpu_start = n_layer - n_gpu_layers; const int i_gpu_start = n_layer - n_gpu_layers;
(void) i_gpu_start; (void) i_gpu_start;
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
//
// with the low VRAM option VRAM scratch is disabled in llama_model_load_internal
// in that case ggml_cuda_assign_buffers has no effect
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop;
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) {
offload_func_nr = ggml_cuda_assign_buffers;
}
if (n_gpu_layers > n_layer + 1) {
offload_func_v = ggml_cuda_assign_buffers;
}
if (n_gpu_layers > n_layer + 2) {
offload_func_kq = ggml_cuda_assign_buffers;
}
#endif // GGML_USE_CUBLAS
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
offload_func_t offload_func = llama_nop; offload_func_t offload_func = llama_nop;
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
if (il >= i_gpu_start) { if (il >= i_gpu_start) {
offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU offload_func = ggml_cuda_assign_buffers;
} }
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
@ -1373,31 +1454,42 @@ static bool llama_eval_internal(
// self-attention // self-attention
{ {
// compute Q and K and RoPE them // compute Q and K and RoPE them
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
// offload_func(tmpq);
ggml_set_name(tmpq, "tmpq");
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
// offload_func(tmpk); offload_func_kq(tmpk);
ggml_set_name(tmpk, "tmpk"); ggml_set_name(tmpk, "tmpk");
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0); struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0);
offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur"); ggml_set_name(Kcur, "Kcur");
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0); struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur"); ggml_set_name(Qcur, "Qcur");
// store key and value to memory // store key and value to memory
{ {
// compute the transposed [N, n_embd] V matrix // compute the transposed [N, n_embd] V matrix
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
offload_func_v(tmpv);
ggml_set_name(tmpv, "tmpv");
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N));
offload_func_v(Vcur);
ggml_set_name(Vcur, "Vcur"); ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
offload_func_kq(k);
ggml_set_name(k, "k"); ggml_set_name(k, "k");
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
( n_ctx)*ggml_element_size(kv_self.v), ( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
offload_func_v(v);
ggml_set_name(v, "v"); ggml_set_name(v, "v");
// important: storing RoPE-ed version of K in the KV cache! // important: storing RoPE-ed version of K in the KV cache!
@ -1409,6 +1501,7 @@ static bool llama_eval_internal(
ggml_permute(ctx0, ggml_permute(ctx0,
Qcur, Qcur,
0, 2, 1, 3); 0, 2, 1, 3);
offload_func_kq(Q);
ggml_set_name(Q, "Q"); ggml_set_name(Q, "Q");
struct ggml_tensor * K = struct ggml_tensor * K =
@ -1417,10 +1510,12 @@ static bool llama_eval_internal(
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
n_embd/n_head, n_head, n_past + N), n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3); 0, 2, 1, 3);
offload_func_kq(K);
ggml_set_name(K, "K"); ggml_set_name(K, "K");
// K * Q // K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ"); ggml_set_name(KQ, "KQ");
// KQ_scaled = KQ / sqrt(n_embd/n_head) // KQ_scaled = KQ / sqrt(n_embd/n_head)
@ -1429,14 +1524,17 @@ static bool llama_eval_internal(
// KQ_scaled shape [n_past + N, N, n_head, 1] // KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled"); ggml_set_name(KQ_scaled, "KQ_scaled");
// KQ_masked = mask_past(KQ_scaled) // KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked"); ggml_set_name(KQ_masked, "KQ_masked");
// KQ = soft_max(KQ_masked) // KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max"); ggml_set_name(KQ_soft_max, "KQ_soft_max");
// split cached V into n_head heads // split cached V into n_head heads
@ -1446,10 +1544,12 @@ static bool llama_eval_internal(
n_ctx*ggml_element_size(kv_self.v), n_ctx*ggml_element_size(kv_self.v),
n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head, n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
il*n_ctx*ggml_element_size(kv_self.v)*n_embd); il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
offload_func_v(V);
ggml_set_name(V, "V"); ggml_set_name(V, "V");
#if 1 #if 1
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
offload_func_v(KQV);
ggml_set_name(KQV, "KQV"); ggml_set_name(KQV, "KQV");
#else #else
// make V contiguous in memory to speed up the matmul, however we waste time on the copy // make V contiguous in memory to speed up the matmul, however we waste time on the copy
@ -1461,12 +1561,14 @@ static bool llama_eval_internal(
// KQV_merged = KQV.permute(0, 2, 1, 3) // KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
offload_func_v(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged"); ggml_set_name(KQV_merged, "KQV_merged");
// cur = KQV_merged.contiguous().view(n_embd, N) // cur = KQV_merged.contiguous().view(n_embd, N)
cur = ggml_cpy(ctx0, cur = ggml_cpy(ctx0,
KQV_merged, KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous"); ggml_set_name(cur, "KQV_merged_contiguous");
// projection (no bias) // projection (no bias)
@ -1478,7 +1580,6 @@ static bool llama_eval_internal(
} }
lctx.use_buf(ctx0, 1); lctx.use_buf(ctx0, 1);
//ggml_cuda_set_scratch(1);
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
offload_func(inpFF); offload_func(inpFF);
@ -1536,32 +1637,24 @@ static bool llama_eval_internal(
} }
lctx.use_buf(ctx0, 0); lctx.use_buf(ctx0, 0);
//ggml_cuda_set_scratch(0);
// used at the end to optionally extract the embeddings // used at the end to optionally extract the embeddings
struct ggml_tensor * embeddings = NULL; struct ggml_tensor * embeddings = NULL;
offload_func_t offload_func = llama_nop;
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) {
offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
}
#endif // GGML_USE_CUBLAS
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL);
offload_func(cur); offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_inpL"); ggml_set_name(cur, "rms_norm_inpL");
cur = ggml_rms_norm(ctx0, cur); cur = ggml_rms_norm(ctx0, cur);
offload_func(cur); offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_after"); ggml_set_name(cur, "rms_norm_after");
// cur = cur*norm(broadcasted) // cur = cur*norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.norm); cur = ggml_mul(ctx0, cur, model.norm);
offload_func(cur); offload_func_nr(cur);
ggml_set_name(cur, "result_norm"); ggml_set_name(cur, "result_norm");
embeddings = cur; embeddings = cur;
@ -2552,8 +2645,8 @@ struct llama_context * llama_init_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_batch, params.n_gpu_layers, if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_batch, params.n_gpu_layers, params.main_gpu,
params.main_gpu, params.tensor_split, memory_type, params.use_mmap, params.use_mlock, params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock,
params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { params.vocab_only, params.progress_callback, params.progress_callback_user_data)) {
fprintf(stderr, "%s: failed to load model\n", __func__); fprintf(stderr, "%s: failed to load model\n", __func__);
llama_free(ctx); llama_free(ctx);
@ -2562,7 +2655,7 @@ struct llama_context * llama_init_from_file(
// reserve memory for context buffers // reserve memory for context buffers
if (!params.vocab_only) { if (!params.vocab_only) {
if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) { if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx); llama_free(ctx);
return nullptr; return nullptr;

View file

@ -77,6 +77,7 @@ extern "C" {
int n_gpu_layers; // number of layers to store in VRAM int n_gpu_layers; // number of layers to store in VRAM
int main_gpu; // the GPU that is used for scratch and small tensors int main_gpu; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
bool low_vram; // if true, reduce VRAM usage at the cost of performance
int seed; // RNG seed, -1 for random int seed; // RNG seed, -1 for random
bool f16_kv; // use fp16 for KV cache bool f16_kv; // use fp16 for KV cache