From 254a7a7a5ff4c874ff8488f1f5cbdd7e9c89d682 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 14 Jun 2023 19:47:19 +0200 Subject: [PATCH] CUDA full GPU acceleration, KV cache in VRAM (#1827) * Fixed CUDA RoPE * ggml_cuda_mul_mat_vec_p021 * ggml_cuda_scale * ggml_cuda_diag_mask_inf * ggml_is_permuted * ggml_cuda_cpy * flatten rows for ggml_cuda_op * Added a --low-vram option * Fixed Windows performance * Fixed LLAMA_CUDA_DMMV_Y > 1 for WizardLM --- examples/common.cpp | 8 + examples/common.h | 17 +- examples/main/README.md | 1 + examples/server/README.md | 1 + examples/server/server.cpp | 9 + ggml-cuda.cu | 797 ++++++++++++++++++++++++++++++++----- ggml-cuda.h | 2 + ggml.c | 6 + ggml.h | 1 + llama.cpp | 159 ++++++-- llama.h | 1 + 11 files changed, 853 insertions(+), 149 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index df69f2736..dc69e5373 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -331,6 +331,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } #else fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n"); +#endif // GGML_USE_CUBLAS + } else if (arg == "--low-vram" || arg == "-lv") { +#ifdef GGML_USE_CUBLAS + params.low_vram = true; +#else + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n"); #endif // GGML_USE_CUBLAS } else if (arg == "--no-mmap") { params.use_mmap = false; @@ -479,6 +485,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -ts SPLIT --tensor-split SPLIT\n"); fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" ); + fprintf(stderr, " -lv, --low-vram don't allocate VRAM scratch buffer\n" ); #endif fprintf(stderr, " --mtest compute maximum memory usage\n"); fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n"); @@ -528,6 +535,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { lparams.n_gpu_layers = params.n_gpu_layers; lparams.main_gpu = params.main_gpu; memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float)); + lparams.low_vram = params.low_vram; lparams.seed = params.seed; lparams.f16_kv = params.memory_f16; lparams.use_mmap = params.use_mmap; diff --git a/examples/common.h b/examples/common.h index 6fedb414a..6c2953cb2 100644 --- a/examples/common.h +++ b/examples/common.h @@ -21,15 +21,16 @@ int32_t get_num_physical_cores(); struct gpt_params { - int32_t seed = -1; // RNG seed - int32_t n_threads = get_num_physical_cores(); - int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 512; // context size - int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_gpu_layers = 0; // number of layers to store in VRAM - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + int32_t seed = -1; // RNG seed + int32_t n_threads = get_num_physical_cores(); + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 512; // context size + int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_gpu_layers = 0; // number of layers to store in VRAM + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs + bool low_vram = 0; // if true, reduce VRAM usage at the cost of performance // sampling parameters std::unordered_map logit_bias; // logit bias for specific tokens diff --git a/examples/main/README.md b/examples/main/README.md index 149d507a8..b6d3212fe 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -288,5 +288,6 @@ These options provide extra functionality and customization when running the LLa - `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance. - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS. - `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS. +- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS. - `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains. - `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation. diff --git a/examples/server/README.md b/examples/server/README.md index b011302fc..7dabac9cf 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -289,6 +289,7 @@ Test(); - `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance. - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS. - `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS. +- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS. - `--embedding`: Enable the embedding mode. **Completion function doesn't work in this mode**. - `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`; - `--port`: Set the port to listen. Default: `8080`. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 31d8087ef..872750053 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -405,6 +405,7 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params ¶ms) fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" ); + fprintf(stderr, " -lv, --low-vram don't allocate VRAM scratch buffer\n" ); #endif fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); @@ -537,6 +538,14 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para } #else fprintf(stderr, "WARNING: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n"); +#endif // GGML_USE_CUBLAS + } + else if (arg == "--low-vram" || arg == "-lv") + { +#ifdef GGML_USE_CUBLAS + params.low_vram = true; +#else + fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n"); #endif // GGML_USE_CUBLAS } else if (arg == "--main-gpu" || arg == "-mg") diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 3b9a5ddfb..0565571f4 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -48,6 +49,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v); +typedef void (*cpy_kernel_t)(const char * cx, char * cdst); typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); typedef void (*ggml_cuda_op_t)( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i, @@ -151,7 +153,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_ADD_BLOCK_SIZE 256 #define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_CPY_BLOCK_SIZE 32 +#define CUDA_SCALE_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 +#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 // dmmv = dequantize_mul_mat_vec @@ -655,10 +660,15 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) } template -static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { +static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) { // qk = quantized weights per x block // qr = number of quantized weights per data value in x block - const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int row = blockIdx.y*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + const int tid = threadIdx.x; const int iter_stride = 2*GGML_CUDA_DMMV_X; @@ -703,8 +713,13 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } template -static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) { - const int row = blockIdx.x*blockDim.y + threadIdx.y; +static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols, const int nrows) { + const int row = blockIdx.y*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + const int tid = threadIdx.x; const int iter_stride = QK_K; @@ -737,6 +752,139 @@ static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y } } +static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) { + const half * x = (half *) vx; + + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; + + const int nrows_y = ncols_x; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + float tmp = 0.0f; + + for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { + const int col_x = col_x0 + threadIdx.x; + + if (col_x >= ncols_x) { + break; + } + + // x is transposed and permuted + const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x; + const float xi = __half2float(x[ix]); + + const int row_y = col_x; + + + // y is not transposed but permuted + const int iy = channel*nrows_y + row_y; + + tmp += xi * y[iy]; + } + + // dst is not transposed and not permuted + const int idst = channel*nrows_dst + row_dst; + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[idst] = tmp; + } +} + +static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, + const int row_stride_x, const int nchannels_x, const int channel_stride_x) { + + const half * x = (half *) vx; + + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; + + const int nrows_y = ncols_x; + const int nrows_dst = nrows_x; + const int row_dst = row_x; + + const int idst = channel*nrows_dst + row_dst; + + float tmp = 0.0f; + + for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { + const int col_x = col_x0 + threadIdx.x; + + if (col_x >= ncols_x) { + break; + } + + const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x; + const float xi = __half2float(x[ix]); + + const int row_y = col_x; + + const int iy = channel*nrows_y + row_y; + + tmp += xi * y[iy]; + } + + // sum up partial sums and write back result + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[idst] = tmp; + } +} + +static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { + const float * xi = (float *) cxi; + float * dsti = (float *) cdsti; + + *dsti = *xi; +} + +static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { + const float * xi = (float *) cxi; + half * dsti = (half *) cdsti; + + *dsti = __float2half(*xi); +} + +template +static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor + // then combine those indices with the corresponding byte offsets to get the total offsets + const int i02 = i / (ne00*ne01); + const int i01 = (i - i02*ne01*ne00) / ne00; + const int i00 = i - i02*ne01*ne00 - i01*ne00; + const int x_offset = i00*nb00 + i01*nb01 + i02*nb02; + + const int i12 = i / (ne10*ne11); + const int i11 = (i - i12*ne10*ne11) / ne10; + const int i10 = i - i12*ne10*ne11 - i11*ne10; + const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12; + + cpy_1(cx + x_offset, cdst + dst_offset); +} + +// rope == RoPE == rotary positional embedding static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) { const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x); @@ -758,6 +906,72 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c dst[i + 1] = x0*sin_theta + x1*cos_theta; } +static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { + const int col = blockDim.x*blockIdx.x + threadIdx.x; + const int row = blockDim.y*blockIdx.y + threadIdx.y; + + if (col >= ncols) { + return; + } + + const int i = row*ncols + col; + // dst[i] = col > n_past + row ? -INFINITY : x[i]; + dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU +} + +// the CUDA soft max implementation differs from the CPU implementation +// instead of doubles floats are used +// values are also not normalized to the maximum value by subtracting it in the exponential function +// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine +static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) { + const int row = blockDim.y*blockIdx.y + threadIdx.y; + const int block_size = blockDim.x; + const int tid = threadIdx.x; + + float tmp = 0.0; + + for (int block_start = 0; block_start < ncols; block_start += block_size) { + const int col = block_start + tid; + + if (col >= ncols) { + break; + } + + const int i = row*ncols + col; + const float val = expf(x[i]); + tmp += val; + dst[i] = val; + } + + // sum up partial sums + __syncthreads(); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + for (int block_start = 0; block_start < ncols; block_start += block_size) { + const int col = block_start + tid; + + if (col >= ncols) { + break; + } + + const int i = row*ncols + col; + dst[i] /= tmp; + } +} + +static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = scale * x[i]; +} + static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; add_f32<<>>(x, y, dst, k); @@ -831,73 +1045,92 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + <<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + <<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + <<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + <<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols); + <<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<(nrows + ny - 1)/ny, block_dims, 0, stream>>>(vx, y, dst, ncols); + dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 2, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<>>(vx, y, dst, ncols); + const int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 2, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<>>(vx, y, dst, ncols); + const int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 2, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<>>(vx, y, dst, ncols); + const int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 2, 1); - dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<>>(vx, y, dst, ncols); + const int ny = 2; + const int block_num_y = (nrows + ny - 1) / ny; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(32, ny, 1); + dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<>>(vx, y, dst, ncols, nrows); } static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { @@ -907,10 +1140,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); + const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec<1, 1, convert_f16> - <<>>(vx, y, dst, ncols); + <<>>(vx, y, dst, ncols, nrows); } static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { @@ -942,6 +1176,47 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } +static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) { + const dim3 block_nums(1, nrows_x, nchannels_x); + const dim3 block_dims(WARP_SIZE, 1, 1); + mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x); +} + +static void ggml_mul_mat_vec_nc_f16_f32_cuda( + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, + const int nchannels_x, const int channel_stride_x, cudaStream_t stream) { + + const dim3 block_nums(1, nrows_x, nchannels_x); + const dim3 block_dims(WARP_SIZE, 1, 1); + mul_mat_vec_nc_f16_f32<<>> + (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x); +} + +static void ggml_cpy_f32_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_f32_f16<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void ggml_cpy_f32_f16_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, + const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { + + const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_f32_f16<<>> + (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); +} + +static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; + scale_f32<<>>(x, dst, scale, k); +} + static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(nrows % 2 == 0); const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1); @@ -950,6 +1225,19 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i rope_f32<<>>(x, dst, ncols, p, theta_scale); } +static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { + const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1); + const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; + const dim3 block_nums(block_num_x, nrows_x, 1); + diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); +} + +static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(1, nrows_x, 1); + soft_max_f32<<>>(x, dst, ncols_x); +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 256 @@ -1120,10 +1408,25 @@ void ggml_cuda_host_free(void * ptr) { CUDA_CHECK(cudaFreeHost(ptr)); } -static cudaError_t ggml_cuda_h2d_tensor_2d( +static cudaError_t ggml_cuda_cpy_tensor_2d( void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { - char * dst_char = (char *) dst; + cudaMemcpyKind kind; + char * src_ptr; + if (src->backend == GGML_BACKEND_CPU) { + kind = cudaMemcpyHostToDevice; + src_ptr = (char *) src->data; + } else if (src->backend == GGML_BACKEND_GPU) { + kind = cudaMemcpyDeviceToDevice; + struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; + int id; + CUDA_CHECK(cudaGetDevice(&id)); + src_ptr = (char *) extra->data_device[id]; + } else { + GGML_ASSERT(false); + } + char * dst_ptr = (char *) dst; + const int64_t ne0 = src->ne[0]; const int64_t nb0 = src->nb[0]; const int64_t nb1 = src->nb[1]; @@ -1134,17 +1437,17 @@ static cudaError_t ggml_cuda_h2d_tensor_2d( const int64_t bs = ggml_blck_size(type); int64_t i1_diff = i1_high - i1_low; - const void * x = (const void *) ((const char *) src->data + i1_low*nb1 + i2*nb2 + i3*nb3); + const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; if (nb0 == ts && nb1 == ts*ne0/bs) { - return cudaMemcpyAsync(dst_char, x, i1_diff*nb1, cudaMemcpyHostToDevice, stream); + return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream); } else if (nb0 == ts) { - return cudaMemcpy2DAsync(dst_char, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyHostToDevice, stream); + return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream); } else { for (int64_t i1 = 0; i1 < i1_diff; i1++) { const void * rx = (const void *) ((const char *) x + i1*nb1); - void * rd = (void *) (dst_char + i1*ts*ne0/bs); + void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); // pretend the row is a matrix with cols=1 - cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream); + cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream); if (r != cudaSuccess) return r; } return cudaSuccess; @@ -1380,8 +1683,81 @@ inline void ggml_cuda_op_rope( (void) i1; } +inline void ggml_cuda_op_diag_mask_inf( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, + float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t & cudaStream_main){ + + GGML_ASSERT(src0_ddf_i != nullptr); + GGML_ASSERT(dst_ddf_i != nullptr); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t i01_diff = i01_high - i01_low; + + const int n_past = ((int32_t *) src1->data)[0]; + + // compute + diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main); + CUDA_CHECK(cudaGetLastError()); + + (void) dst; + (void) src0_ddq_i; + (void) src1_ddf_i; + (void) i02; + (void) i1; +} + +inline void ggml_cuda_op_soft_max( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, + float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t & cudaStream_main){ + + GGML_ASSERT(src0_ddf_i != nullptr); + GGML_ASSERT(dst_ddf_i != nullptr); + + const int64_t ne00 = src0->ne[0]; + const int64_t i01_diff = i01_high - i01_low; + + // compute + soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + (void) src0_ddq_i; + (void) src1_ddf_i; + (void) i02; + (void) i1; +} + +inline void ggml_cuda_op_scale( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, + float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t & cudaStream_main){ + + GGML_ASSERT(src0_ddf_i != nullptr); + GGML_ASSERT(dst_ddf_i != nullptr); + + const float scale = ((float *) src1->data)[0]; + + const int64_t ne00 = src0->ne[0]; + const int64_t i01_diff = i01_high - i01_low; + + // compute + scale_f32_cuda(src0_ddf_i, dst_ddf_i, scale, ne00*i01_diff, cudaStream_main); + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + (void) src0_ddq_i; + (void) src1_ddf_i; + (void) i02; + (void) i1; +} + static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - ggml_cuda_op_t op, bool src0_needs_f32) { + ggml_cuda_op_t op, bool src0_needs_f32, bool flatten_rows) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; @@ -1404,21 +1780,27 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); // strides for iteration over dims 3 and 2 - const int64_t src0_stride = ne00 * ne01; - const int64_t src1_stride = ne10 * ne11; - const int64_t dst_stride = ne0 * ne1; - const int64_t num_iters = ne02 * ne03; + const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03; + const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1; + const int64_t src0_stride = ne00 * ne01 * stride_mod; + const int64_t src1_stride = ne10 * ne11 * stride_mod; + const int64_t dst_stride = ne0 * ne1 * stride_mod; const size_t src0_ts = ggml_type_size(src0->type); const size_t src0_bs = ggml_blck_size(src0->type); - struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; struct ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; - struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT; + const bool src0_is_contiguous = ggml_is_contiguous(src0); const bool src0_is_f32 = src0->type == GGML_TYPE_F32; + const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1); + const bool src1_stays_on_host = use_src1 && ( + dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE); + const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); @@ -1427,13 +1809,13 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm char * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; - float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; + float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // asq = actual size quantized, asf = actual size float size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0}; size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0}; size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0}; - size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0}; + size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0}; for (int id = 0; id < g_device_count; ++id) { if (!split && id != g_main_device) { @@ -1446,9 +1828,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm int64_t row_low, row_high; if (split) { row_low = id == 0 ? 0 : nrows0*g_tensor_split[id]; - row_low -= row_low % GGML_CUDA_DMMV_Y; row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1]; - row_high -= row_high % GGML_CUDA_DMMV_Y; } else { row_low = 0; row_high = nrows0; @@ -1461,7 +1841,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm cudaSetDevice(id); - if (src0_on_device) { + if (src0_on_device && src0_is_contiguous) { if (src0_is_f32) { src0_ddf[id] = (float *) src0_extra->data_device[id]; } else { @@ -1479,8 +1859,8 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]); } - if (use_src1) { - if (src1_on_device) { + if (use_src1 && !src1_stays_on_host) { + if (src1_on_device && src1_is_contiguous) { src1_ddf[id] = (float *) src1_extra->data_device[id]; } else { src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]); @@ -1493,26 +1873,32 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]); } - for (int64_t i03 = 0; i03 < ne03; i03++) { + const int64_t i03_max = flatten_rows ? 1 : ne03; + const int64_t i02_max = flatten_rows ? 1 : ne02; + const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01; + + for (int64_t i03 = 0; i03 < i03_max; i03++) { const int64_t i13 = i03 % ne13; - for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i02 = 0; i02 < i02_max; i02++) { const int64_t i12 = i02 % ne12; const int64_t i0 = i03*ne02 + i02; - const int64_t i0_offset_low = row_low/ne01; - const int64_t i0_offset_high = row_high/ne01; + + // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs + const int64_t i0_offset_low = row_low/rows_per_iter; + const int64_t i0_offset_high = row_high/rows_per_iter; int64_t i01_low = 0; - int64_t i01_high = ne01; + int64_t i01_high = rows_per_iter; if (split) { if (i0 < i0_offset_low || i0 > i0_offset_high) { continue; } if (i0 == i0_offset_low) { - i01_low = row_low % ne01; + i01_low = row_low % rows_per_iter; } if (i0 == i0_offset_high) { - i01_high = row_high % ne01; + i01_high = row_high % rows_per_iter; } } @@ -1521,7 +1907,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm // Removing both asserts results in i01_high becoming 0 which in turn results in garbage output. // The root cause seems to be a problem with i0_offset_high becoming 0 when it should always be >0 (for single GPU). GGML_ASSERT(i01_low == 0 || g_device_count > 1); - GGML_ASSERT(i01_high == ne01 || g_device_count > 1); + GGML_ASSERT(i01_high == rows_per_iter || g_device_count > 1); const int64_t i01_diff = i01_high - i01_low; if (i01_diff == 0) { @@ -1529,24 +1915,23 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } const int64_t i11 = i13*ne12 + i12; - cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS]; + cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS]; cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS]; - cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS]; + cudaEvent_t cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS]; // for split tensors the data begins at i0 == i0_offset_low char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs; float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride; float * src1_ddf_i = src1_ddf[id] + i11*src1_stride; - float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; + float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; // for split tensors the data pointer needs to be rounded down // to the bin edge for i03, i02 bins beyond the first if (i0 - i0_offset_low > 0) { + GGML_ASSERT(!flatten_rows); src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs; src0_ddf_i -= (row_low % ne01)*ne00; - } - if (i0 - i0_offset_low > 0) { - dst_ddf_i -= (row_low % ne0)*ne1; + dst_ddf_i -= (row_low % ne0)*ne1; } // the main device memory buffer can be on VRAM scratch, with space for all partial results @@ -1556,30 +1941,37 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } // copy src0, src1 to device if necessary - if (use_src1) { + if (use_src1 && !src1_stays_on_host) { if (src1->backend == GGML_BACKEND_CPU) { - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_memcpy_src1)); - } else if (src1->backend == GGML_BACKEND_GPU) { + GGML_ASSERT(!flatten_rows || nrows0 == ggml_nrows(src1)); + int64_t nrows1 = flatten_rows ? nrows0 : ne11; + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, nrows1, cudaStream_memcpy_src1)); + } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) { if (id != g_main_device) { + GGML_ASSERT(!flatten_rows); float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device]; src1_ddf_i_source += i11*src1_stride; CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float), cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1)); } + } else if (src1_on_device && !src1_is_contiguous) { + GGML_ASSERT(!split); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_main)); } else { GGML_ASSERT(false); } } CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1)); - if (!src0_on_device) { + + if (!src0_on_device || !src0_is_contiguous) { if (src0_is_f32) { - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); } else { - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); } } - // convert src0 to f32 if it's necessary for the ggml_cuda_op + // convert src0 to f32 if it is necessary for the ggml_cuda_op if (src0_needs_f32 && !src0_is_f32) { to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main); CUDA_CHECK(cudaGetLastError()); @@ -1644,39 +2036,30 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true); } void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten } void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true); } void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true); } bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - GGML_ASSERT(src0->backend != GGML_BACKEND_GPU); const int64_t ne10 = src1->ne[0]; const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; - // if (strcmp(dst->name, "KQ") == 0 || strcmp(dst->name, "KQV") == 0) { - // fprintf(stderr, "(%ld, %ld, %ld, %ld) + (%ld, %ld, %ld, %ld) -> (%ld, %ld, %ld, %ld)\n", - // src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - // src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - // dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); - // return false; - // } - // TODO: find the optimal values for these if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32 && @@ -1688,23 +2071,158 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te return false; } +void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + CUDA_CHECK(cudaSetDevice(g_main_device)); + cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0]; + + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main); + + CUDA_CHECK(cudaDeviceSynchronize()); +} + +void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); + GGML_ASSERT(!ggml_is_permuted(src0)); + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; + + CUDA_CHECK(cudaSetDevice(g_main_device)); + cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0]; + + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + const int row_stride_x = nb01 / sizeof(half); + const int channel_stride_x = nb02 / sizeof(half); + + ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main); + + CUDA_CHECK(cudaDeviceSynchronize()); +} + void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - if (src0->type == GGML_TYPE_F32) { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true); + bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && + src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; + + if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + ggml_cuda_mul_mat_vec_p021(src0, src1, dst); + } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { + ggml_cuda_mul_mat_vec_nc(src0, src1, dst); + }else if (src0->type == GGML_TYPE_F32) { + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - if (src1->ne[1] == 1) { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); + if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) { + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false); } else { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); } } else { GGML_ASSERT(false); } } +void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_scale, true, true); +} + +void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const int64_t ne = ggml_nelements(src0); + GGML_ASSERT(ne == ggml_nelements(src1)); + + GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); + + GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); + GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + GGML_ASSERT(src0->ne[3] == 1); + + const int64_t nb00 = src0->nb[0]; + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + GGML_ASSERT(src1->ne[3] == 1); + + const int64_t nb10 = src1->nb[0]; + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; + + CUDA_CHECK(cudaSetDevice(g_main_device)); + cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device][0]; + + const struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + const struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + + char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; + char * src1_ddc = (char *) src1_extra->data_device[g_main_device]; + + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + ggml_cpy_f32_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, + ne10, ne11, nb10, nb11, nb12, cudaStream_main); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { + ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, + ne10, ne11, nb10, nb11, nb12, cudaStream_main); + } else { + GGML_ASSERT(false); + } + + CUDA_CHECK(cudaDeviceSynchronize()); + + (void) dst; +} + +void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_diag_mask_inf, true, true); +} + +void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_soft_max, true, true); +} + void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results } void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -1718,10 +2236,9 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { const size_t nb1 = tensor->nb[1]; ggml_backend backend = tensor->backend; struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu; + memset(extra, 0, sizeof(*extra)); for (int id = 0; id < g_device_count; ++id) { - extra->data_device[id] = nullptr; - if (backend == GGML_BACKEND_GPU && id != g_main_device) { continue; } @@ -1734,10 +2251,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { row_high = nrows; } else if (backend == GGML_BACKEND_GPU_SPLIT) { row_low = id == 0 ? 0 : nrows*g_tensor_split[id]; - row_low -= row_low % GGML_CUDA_DMMV_Y; row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1]; - row_high -= row_high % GGML_CUDA_DMMV_Y; - GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0); } else { GGML_ASSERT(false); } @@ -1781,45 +2295,76 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { delete extra; } -void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { - if (tensor->src0 != nullptr && tensor->src0->op == GGML_OP_RESHAPE) { - ggml_cuda_assign_buffers(tensor); +void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) { + if (scratch && g_scratch_size == 0) { + return; } - const size_t size = ggml_nbytes(tensor); - GGML_ASSERT(size <= g_scratch_size); - if (g_scratch_offset + size > g_scratch_size) { - g_scratch_offset = 0; + // recursively assign CUDA buffers until a compute tensor is found + if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) { + const ggml_op src0_op = tensor->src0->op; + if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) { + ggml_cuda_assign_buffers_impl(tensor->src0, scratch); + } + } + if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) { + ggml_cuda_assign_buffers_impl(tensor->src1, scratch); } tensor->backend = GGML_BACKEND_GPU; struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu; - bool inplace = tensor->src0 != nullptr && tensor->src0->data == tensor->data; + const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) || + tensor->op == GGML_OP_VIEW; + const size_t size = ggml_nbytes(tensor); CUDA_CHECK(cudaSetDevice(g_main_device)); if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) { struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra; - extra->data_device[g_main_device] = src0_extra->data_device; - GGML_ASSERT(false); - } else { + char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; + size_t offset = 0; + if (tensor->op == GGML_OP_VIEW) { + memcpy(&offset, tensor->opt[0]->data, sizeof(size_t)); + } + extra->data_device[g_main_device] = src0_ddc + offset; + } else if (tensor->op == GGML_OP_CPY) { + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src1->extra; + void * src1_ddv = src1_extra->data_device[g_main_device]; + extra->data_device[g_main_device] = src1_ddv; + } else if (scratch) { + GGML_ASSERT(size <= g_scratch_size); + if (g_scratch_offset + size > g_scratch_size) { + g_scratch_offset = 0; + } + char * data = (char *) g_scratch_buffer; if (data == nullptr) { CUDA_CHECK(cudaMalloc(&data, g_scratch_size)); g_scratch_buffer = data; } extra->data_device[g_main_device] = data + g_scratch_offset; + + g_scratch_offset += size; + + GGML_ASSERT(g_scratch_offset <= g_scratch_size); + } else { // allocate new buffers outside of scratch + void * data; + CUDA_CHECK(cudaMalloc(&data, size)); + CUDA_CHECK(cudaMemset(data, 0, size)); + extra->data_device[g_main_device] = data; } - // fprintf(stderr, "data=%p offset=%ld data_device=%p\n", data, g_scratch_offset, extra->data_device[0]); - g_scratch_offset += size; - // fprintf(stderr, "%s: scratch %d, %p - %p\n", - // tensor->name, g_scratch_index, data + g_scratch_offset, data + g_scratch_offset + size); - - GGML_ASSERT(g_scratch_offset <= g_scratch_size); tensor->extra = extra; } +void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, true); +} + +void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) { + ggml_cuda_assign_buffers_impl(tensor, false); +} + void ggml_cuda_set_main_device(int main_device) { if (main_device > g_device_count) { fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n", @@ -1838,6 +2383,15 @@ void ggml_cuda_set_scratch_size(size_t scratch_size) { g_scratch_size = scratch_size; } +void ggml_cuda_free_scratch() { + if (g_scratch_buffer == nullptr) { + return; + } + + CUDA_CHECK(cudaFree(g_scratch_buffer)); + g_scratch_buffer = nullptr; +} + bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){ ggml_cuda_func_t func; const bool any_on_device = tensor->backend == GGML_BACKEND_GPU @@ -1875,12 +2429,39 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ } func = ggml_cuda_mul_mat; break; + case GGML_OP_SCALE: + if (!any_on_device) { + return false; + } + func = ggml_cuda_scale; + break; + case GGML_OP_CPY: + if (!any_on_device) { + return false; + } + func = ggml_cuda_cpy; + break; case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: if (!any_on_device) { return false; } func = ggml_cuda_nop; break; + case GGML_OP_DIAG_MASK_INF: + if (!any_on_device) { + return false; + } + func = ggml_cuda_diag_mask_inf; + break; + case GGML_OP_SOFT_MAX: + if (!any_on_device) { + return false; + } + func = ggml_cuda_soft_max; + break; case GGML_OP_ROPE: if (!any_on_device) { return false; diff --git a/ggml-cuda.h b/ggml-cuda.h index fde6d4085..d32b44842 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -28,8 +28,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor); void ggml_cuda_free_data(struct ggml_tensor * tensor); void ggml_cuda_assign_buffers(struct ggml_tensor * tensor); +void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor); void ggml_cuda_set_main_device(int main_device); void ggml_cuda_set_scratch_size(size_t scratch_size); +void ggml_cuda_free_scratch(void); bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor); #ifdef __cplusplus diff --git a/ggml.c b/ggml.c index 32c191307..c0efa1977 100644 --- a/ggml.c +++ b/ggml.c @@ -3939,6 +3939,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) { tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } +bool ggml_is_permuted(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; +} + static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); diff --git a/ggml.h b/ggml.h index f2a91761b..9b0c846f8 100644 --- a/ggml.h +++ b/ggml.h @@ -485,6 +485,7 @@ extern "C" { GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); // use this to compute the memory overhead of a tensor GGML_API size_t ggml_tensor_overhead(void); diff --git a/llama.cpp b/llama.cpp index d2a52bb0c..b8bc0d821 100644 --- a/llama.cpp +++ b/llama.cpp @@ -165,6 +165,11 @@ struct llama_kv_cache { if (ctx) { ggml_free(ctx); } + +#ifdef GGML_USE_CUBLAS + ggml_cuda_free_data(k); + ggml_cuda_free_data(v); +#endif // GGML_USE_CUBLAS } }; @@ -210,6 +215,7 @@ struct llama_model { for (size_t i = 0; i < tensors_by_name.size(); ++i) { ggml_cuda_free_data(tensors_by_name[i].second); } + ggml_cuda_free_scratch(); #elif defined(GGML_USE_CLBLAST) for (size_t i = 0; i < tensors_by_name.size(); ++i) { ggml_cl_free_data(tensors_by_name[i].second); @@ -867,7 +873,8 @@ static bool kv_cache_init( const struct llama_hparams & hparams, struct llama_kv_cache & cache, ggml_type wtype, - int n_ctx) { + int n_ctx, + int n_gpu_layers) { const int n_embd = hparams.n_embd; const int n_layer = hparams.n_layer; @@ -893,6 +900,15 @@ static bool kv_cache_init( ggml_set_name(cache.k, "cache_k"); ggml_set_name(cache.v, "cache_v"); +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer + 1) { + ggml_cuda_assign_buffers_no_scratch(cache.v); + } + if (n_gpu_layers > n_layer + 2) { + ggml_cuda_assign_buffers_no_scratch(cache.k); + } +#endif // GGML_USE_CUBLAS + return true; } @@ -903,6 +919,7 @@ struct llama_context_params llama_context_default_params() { /*.gpu_layers =*/ 0, /*.main_gpu =*/ 0, /*.tensor_split =*/ {0}, + /*.low_vram =*/ false, /*.seed =*/ -1, /*.f16_kv =*/ true, /*.logits_all =*/ false, @@ -1011,6 +1028,7 @@ static void llama_model_load_internal( int n_gpu_layers, int main_gpu, const float * tensor_split, + bool low_vram, ggml_type memory_type, bool use_mmap, bool use_mlock, @@ -1137,18 +1155,34 @@ static void llama_model_load_internal( ml->ggml_ctx = ctx; model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU); - model.norm = ml->get_tensor("norm.weight", {n_embd}, GGML_BACKEND_CPU); // "output" tensor { + ggml_backend backend_norm; ggml_backend backend_output; if (n_gpu_layers > int(n_layer)) { // NOLINT + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; } else { + backend_norm = GGML_BACKEND_CPU; backend_output = GGML_BACKEND_CPU; } + model.norm = ml->get_tensor("norm.weight", {n_embd}, backend_norm); model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } } const int i_gpu_start = n_layer - n_gpu_layers; @@ -1208,22 +1242,47 @@ static void llama_model_load_internal( (void) vram_scratch; (void) n_batch; #ifdef GGML_USE_CUBLAS - vram_scratch = n_batch * MB; - ggml_cuda_set_scratch_size(vram_scratch); - if (n_gpu_layers > 0) { - fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n", - __func__, vram_scratch / MB); + if (low_vram) { + fprintf(stderr, "%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__); + ggml_cuda_set_scratch_size(0); // disable scratch + } else { + vram_scratch = n_batch * MB; + ggml_cuda_set_scratch_size(vram_scratch); + if (n_gpu_layers > 0) { + fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n", + __func__, vram_scratch / MB); + } } #endif // GGML_USE_CUBLAS #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); - fprintf(stderr, "%s: offloading %d layers to GPU\n", __func__, n_gpu); + fprintf(stderr, "%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); if (n_gpu_layers > (int) hparams.n_layer) { - fprintf(stderr, "%s: offloading output layer to GPU\n", __func__); + fprintf(stderr, "%s: offloading non-repeating layers to GPU\n", __func__); } + size_t vram_kv_cache = 0; + if (n_gpu_layers > (int) hparams.n_layer + 1) { + if (low_vram) { + fprintf(stderr, "%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); + } else { + fprintf(stderr, "%s: offloading v cache to GPU\n", __func__); + vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2; + } + } + if (n_gpu_layers > (int) hparams.n_layer + 2) { + if (low_vram) { + fprintf(stderr, "%s: cannot offload k cache to GPU due to low VRAM option\n", __func__); + } else { + fprintf(stderr, "%s: offloading k cache to GPU\n", __func__); + vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2; + } + } + const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3; + fprintf(stderr, "%s: offloaded %d/%d layers to GPU\n", + __func__, std::min(n_gpu_layers, max_offloadable_layers), hparams.n_layer + 3); fprintf(stderr, "%s: total VRAM used: %zu MB\n", - __func__, (vram_weights + vram_scratch + MB - 1) / MB); // round up + __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up #else (void) n_gpu_layers; #endif @@ -1262,6 +1321,7 @@ static bool llama_model_load( int n_gpu_layers, int main_gpu, float * tensor_split, + bool low_vram, ggml_type memory_type, bool use_mmap, bool use_mlock, @@ -1269,7 +1329,7 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, lctx, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, memory_type, + llama_model_load_internal(fname, lctx, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; } catch (const std::exception & err) { @@ -1345,12 +1405,33 @@ static bool llama_eval_internal( const int i_gpu_start = n_layer - n_gpu_layers; (void) i_gpu_start; + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + // + // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal + // in that case ggml_cuda_assign_buffers has no effect + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers; + } +#endif // GGML_USE_CUBLAS + for (int il = 0; il < n_layer; ++il) { offload_func_t offload_func = llama_nop; #ifdef GGML_USE_CUBLAS if (il >= i_gpu_start) { - offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU + offload_func = ggml_cuda_assign_buffers; } #endif // GGML_USE_CUBLAS @@ -1373,31 +1454,42 @@ static bool llama_eval_internal( // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - // offload_func(tmpq); - ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - // offload_func(tmpk); + offload_func_kq(tmpk); ggml_set_name(tmpk, "tmpk"); + struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + offload_func_kq(tmpq); + ggml_set_name(tmpq, "tmpq"); + struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0); + offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0); + offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); // store key and value to memory { // compute the transposed [N, n_embd] V matrix - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N)); + + struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + offload_func_v(tmpv); + ggml_set_name(tmpv, "tmpv"); + + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N)); + offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); + offload_func_kq(k); ggml_set_name(k, "k"); + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, ( n_ctx)*ggml_element_size(kv_self.v), (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); + offload_func_v(v); ggml_set_name(v, "v"); // important: storing RoPE-ed version of K in the KV cache! @@ -1409,6 +1501,7 @@ static bool llama_eval_internal( ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func_kq(Q); ggml_set_name(Q, "Q"); struct ggml_tensor * K = @@ -1417,10 +1510,12 @@ static bool llama_eval_internal( ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), n_embd/n_head, n_head, n_past + N), 0, 2, 1, 3); + offload_func_kq(K); ggml_set_name(K, "K"); // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); // KQ_scaled = KQ / sqrt(n_embd/n_head) @@ -1429,14 +1524,17 @@ static bool llama_eval_internal( // KQ_scaled shape [n_past + N, N, n_head, 1] struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); // KQ_masked = mask_past(KQ_scaled) struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); // split cached V into n_head heads @@ -1446,10 +1544,12 @@ static bool llama_eval_internal( n_ctx*ggml_element_size(kv_self.v), n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head, il*n_ctx*ggml_element_size(kv_self.v)*n_embd); + offload_func_v(V); ggml_set_name(V, "V"); #if 1 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); ggml_set_name(KQV, "KQV"); #else // make V contiguous in memory to speed up the matmul, however we waste time on the copy @@ -1461,12 +1561,14 @@ static bool llama_eval_internal( // KQV_merged = KQV.permute(0, 2, 1, 3) struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); ggml_set_name(KQV_merged, "KQV_merged"); // cur = KQV_merged.contiguous().view(n_embd, N) cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + offload_func_v(cur); ggml_set_name(cur, "KQV_merged_contiguous"); // projection (no bias) @@ -1478,7 +1580,6 @@ static bool llama_eval_internal( } lctx.use_buf(ctx0, 1); - //ggml_cuda_set_scratch(1); struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); offload_func(inpFF); @@ -1536,32 +1637,24 @@ static bool llama_eval_internal( } lctx.use_buf(ctx0, 0); - //ggml_cuda_set_scratch(0); // used at the end to optionally extract the embeddings struct ggml_tensor * embeddings = NULL; - offload_func_t offload_func = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU - } -#endif // GGML_USE_CUBLAS // norm { cur = ggml_rms_norm(ctx0, inpL); - offload_func(cur); + offload_func_nr(cur); ggml_set_name(cur, "rms_norm_inpL"); cur = ggml_rms_norm(ctx0, cur); - offload_func(cur); + offload_func_nr(cur); ggml_set_name(cur, "rms_norm_after"); // cur = cur*norm(broadcasted) cur = ggml_mul(ctx0, cur, model.norm); - offload_func(cur); + offload_func_nr(cur); ggml_set_name(cur, "result_norm"); embeddings = cur; @@ -2552,8 +2645,8 @@ struct llama_context * llama_init_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_batch, params.n_gpu_layers, - params.main_gpu, params.tensor_split, memory_type, params.use_mmap, params.use_mlock, + if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_batch, params.n_gpu_layers, params.main_gpu, + params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { fprintf(stderr, "%s: failed to load model\n", __func__); llama_free(ctx); @@ -2562,7 +2655,7 @@ struct llama_context * llama_init_from_file( // reserve memory for context buffers if (!params.vocab_only) { - if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) { + if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) { fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; diff --git a/llama.h b/llama.h index 61f6c867d..64292265c 100644 --- a/llama.h +++ b/llama.h @@ -77,6 +77,7 @@ extern "C" { int n_gpu_layers; // number of layers to store in VRAM int main_gpu; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs + bool low_vram; // if true, reduce VRAM usage at the cost of performance int seed; // RNG seed, -1 for random bool f16_kv; // use fp16 for KV cache