From 0728c5a8b9569183ffca0399caac099afef87595 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 31 Jul 2023 15:44:35 +0200 Subject: [PATCH] CUDA: mmq CLI option, fixed mmq build issues (#2453) --- CMakeLists.txt | 16 ++++++++++------ Makefile | 6 +++--- README.md | 4 +++- examples/common.cpp | 16 +++++++++++++--- examples/common.h | 1 + examples/server/server.cpp | 15 +++++++++++++-- ggml-cuda.cu | 24 ++++++++++++++---------- ggml-cuda.h | 1 + llama.cpp | 10 ++++++++-- llama.h | 1 + 10 files changed, 67 insertions(+), 27 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 57678a302..4ecb3d586 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,7 +68,7 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework option(LLAMA_BLAS "llama: use BLAS" OFF) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") option(LLAMA_CUBLAS "llama: use CUDA" OFF) -option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF) +#option(LLAMA_CUDA_CUBLAS "llama: use cuBLAS for prompt processing" OFF) set(LLAMA_CUDA_MMQ_Y "64" CACHE STRING "llama: y tile size for mmq CUDA kernels") option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") @@ -253,9 +253,9 @@ if (LLAMA_CUBLAS) set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h) add_compile_definitions(GGML_USE_CUBLAS) - if (LLAMA_CUDA_CUBLAS) - add_compile_definitions(GGML_CUDA_CUBLAS) - endif() +# if (LLAMA_CUDA_CUBLAS) +# add_compile_definitions(GGML_CUDA_CUBLAS) +# endif() add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y}) if (LLAMA_CUDA_FORCE_DMMV) add_compile_definitions(GGML_CUDA_FORCE_DMMV) @@ -277,10 +277,14 @@ if (LLAMA_CUBLAS) endif() if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + # 52 == lowest CUDA 12 standard + # 60 == f16 CUDA intrinsics + # 61 == integer CUDA intrinsics + # 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster if (LLAMA_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61") # needed for f16 CUDA intrinsics + set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics else() - set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics + set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics endif() endif() message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") diff --git a/Makefile b/Makefile index 616c2d9b8..ebeadfdd0 100644 --- a/Makefile +++ b/Makefile @@ -236,9 +236,9 @@ ifdef LLAMA_CUDA_MMQ_Y else NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64 endif # LLAMA_CUDA_MMQ_Y -ifdef LLAMA_CUDA_CUBLAS - NVCCFLAGS += -DGGML_CUDA_CUBLAS -endif # LLAMA_CUDA_CUBLAS +#ifdef LLAMA_CUDA_CUBLAS +# NVCCFLAGS += -DGGML_CUDA_CUBLAS +#endif # LLAMA_CUDA_CUBLAS ifdef LLAMA_CUDA_CCBIN NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN) endif diff --git a/README.md b/README.md index 42fc42b05..b231d24b8 100644 --- a/README.md +++ b/README.md @@ -400,9 +400,11 @@ Building the program with BLAS support may lead to some performance improvements The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. The following compilation options are also available to tweak performance: + | Option | Legal values | Default | Description | |-------------------------|------------------------|---------|-------------| - | LLAMA_CUDA_CUBLAS | Boolean | false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). | | LLAMA_CUDA_MMQ_Y | Positive integer >= 32 | 64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. | | LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | diff --git a/examples/common.cpp b/examples/common.cpp index fe7308b17..e6439841d 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -352,7 +352,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { #ifdef GGML_USE_CUBLAS params.main_gpu = std::stoi(argv[i]); #else - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n"); + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n"); #endif } else if (arg == "--tensor-split" || arg == "-ts") { if (++i >= argc) { @@ -376,13 +376,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } } #else - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n"); + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n"); +#endif // GGML_USE_CUBLAS + } else if (arg == "--mul-mat-q" || arg == "-mmq") { +#ifdef GGML_USE_CUBLAS + params.mul_mat_q = true; +#else + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n"); #endif // GGML_USE_CUBLAS } else if (arg == "--low-vram" || arg == "-lv") { #ifdef GGML_USE_CUBLAS params.low_vram = true; #else - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n"); + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n"); #endif // GGML_USE_CUBLAS } else if (arg == "--no-mmap") { params.use_mmap = false; @@ -585,6 +591,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" ); fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n" ); + fprintf(stdout, " -mmq, --mul-mat-q use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" ); + fprintf(stdout, " Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" ); + fprintf(stdout, " is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" ); #endif fprintf(stdout, " --mtest compute maximum memory usage\n"); fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n"); @@ -637,6 +646,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param lparams.main_gpu = params.main_gpu; lparams.tensor_split = params.tensor_split; lparams.low_vram = params.low_vram; + lparams.mul_mat_q = params.mul_mat_q; lparams.seed = params.seed; lparams.f16_kv = params.memory_f16; lparams.use_mmap = params.use_mmap; diff --git a/examples/common.h b/examples/common.h index 1184f32df..974484207 100644 --- a/examples/common.h +++ b/examples/common.h @@ -74,6 +74,7 @@ struct gpt_params { size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score bool low_vram = false; // if true, reduce VRAM usage at the cost of performance + bool mul_mat_q = false; // if true, use experimental mul_mat_q kernels bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 83c03065a..c0725088f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -631,6 +631,9 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n"); + fprintf(stdout, " -mmq, --mul-mat-q use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" ); + fprintf(stdout, " Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" ); + fprintf(stdout, " is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" ); #endif fprintf(stdout, " -m FNAME, --model FNAME\n"); fprintf(stdout, " model path (default: %s)\n", params.model.c_str()); @@ -827,7 +830,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } } #else - LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.", {}); + LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {}); #endif // GGML_USE_CUBLAS } else if (arg == "--low-vram" || arg == "-lv") @@ -835,7 +838,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, #ifdef GGML_USE_CUBLAS params.low_vram = true; #else - fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n"); + LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n", {}); +#endif // GGML_USE_CUBLAS + } + else if (arg == "--mul-mat-q" || arg == "-mmq") + { +#ifdef GGML_USE_CUBLAS + params.mul_mat_q = true; +#else + LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n", {}); #endif // GGML_USE_CUBLAS } else if (arg == "--main-gpu" || arg == "-mg") diff --git a/ggml-cuda.cu b/ggml-cuda.cu index bcdff3640..f11fbe57c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3898,10 +3898,9 @@ static size_t g_scratch_offset = 0; static int g_device_count = -1; static int g_main_device = 0; -#ifndef GGML_CUDA_FORCE_DMMV static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; -#endif static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; +static bool g_mul_mat_q = false; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; @@ -3923,9 +3922,7 @@ void ggml_init_cublas() { g_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; -#ifndef GGML_CUDA_FORCE_DMMV g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; -#endif } for (int id = 0; id < g_device_count; ++id) { g_tensor_split[id] /= total_vram; @@ -4278,6 +4275,7 @@ inline void ggml_cuda_op_mul_mat_vec( #ifdef GGML_CUDA_FORCE_DMMV const bool use_mul_mat_vec_q = false; + (void) g_compute_capabilities[0]; #else int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -5021,12 +5019,14 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false); } else { -#ifdef GGML_CUDA_CUBLAS - const bool use_mul_mat_q = false; -#else - const bool use_mul_mat_q = ggml_is_quantized(src0->type); -#endif // GGML_CUDA_CUBLAS - if (use_mul_mat_q) { + int min_compute_capability = INT_MAX; + for (int id = 0; id < g_device_count; ++id) { + if (min_compute_capability > g_compute_capabilities[id]) { + min_compute_capability = g_compute_capabilities[id]; + } + } + + if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false); } else { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); @@ -5320,6 +5320,10 @@ void ggml_cuda_set_main_device(int main_device) { } } +void ggml_cuda_set_mul_mat_q(bool mul_mat_q) { + g_mul_mat_q = mul_mat_q; +} + void ggml_cuda_set_scratch_size(size_t scratch_size) { g_scratch_size = scratch_size; } diff --git a/ggml-cuda.h b/ggml-cuda.h index 3c1e8deb6..72d7afa46 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -27,6 +27,7 @@ void ggml_cuda_assign_buffers(struct ggml_tensor * tensor); void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor); void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor); void ggml_cuda_set_main_device(int main_device); +void ggml_cuda_set_mul_mat_q(bool mul_mat_q); void ggml_cuda_set_scratch_size(size_t scratch_size); void ggml_cuda_free_scratch(void); bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor); diff --git a/llama.cpp b/llama.cpp index 50da4274f..d427054dd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -901,6 +901,7 @@ struct llama_context_params llama_context_default_params() { /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.low_vram =*/ false, + /*.mul_mat_q =*/ false, /*.f16_kv =*/ true, /*.logits_all =*/ false, /*.vocab_only =*/ false, @@ -1028,6 +1029,7 @@ static void llama_model_load_internal( int n_gpu_layers, int main_gpu, const float * tensor_split, + const bool mul_mat_q, float rope_freq_base, float rope_freq_scale, bool low_vram, @@ -1156,9 +1158,11 @@ static void llama_model_load_internal( } (void) main_gpu; + (void) mul_mat_q; #if defined(GGML_USE_CUBLAS) fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__); ggml_cuda_set_main_device(main_gpu); + ggml_cuda_set_mul_mat_q(mul_mat_q); #define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT #elif defined(GGML_USE_CLBLAST) @@ -1367,6 +1371,7 @@ static bool llama_model_load( int n_gpu_layers, int main_gpu, const float * tensor_split, + const bool mul_mat_q, float rope_freq_base, float rope_freq_scale, bool low_vram, @@ -1377,7 +1382,8 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, + llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, + main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; } catch (const std::exception & err) { @@ -3192,7 +3198,7 @@ struct llama_model * llama_load_model_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers, - params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram, + params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { delete model; diff --git a/llama.h b/llama.h index df46f9b9c..fa1977f2d 100644 --- a/llama.h +++ b/llama.h @@ -108,6 +108,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. bool low_vram; // if true, reduce VRAM usage at the cost of performance + bool mul_mat_q; // if true, use experimental mul_mat_q kernels bool f16_kv; // use fp16 for KV cache bool logits_all; // the llama_eval() call computes all logits, not just the last one bool vocab_only; // only load the vocabulary, no weights