From 1fcdcc28b119a6608774d52de905931bd5f8a43d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Thu, 25 May 2023 23:07:29 +0200
Subject: [PATCH] cuda : performance optimizations (#1530)

* xor hack

* block y dim

* loop unrolling

* Fixed cmake LLAMA_CUDA_BY option

* Removed hipblas compatibility code

* Define GGML_CUDA_DMMV_BLOCK_Y if not defined

* Fewer iters, more ops per iter

* Renamed DMMV X/Y compilation options
---
 CMakeLists.txt |  52 ++++++++++++-----------
 Makefile       |  12 +++++-
 ggml-cuda.cu   | 110 +++++++++++++++++++++++++++++++------------------
 3 files changed, 110 insertions(+), 64 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 39db2e3fc..31c5bd91d 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,42 +37,44 @@ endif()
 #
 
 # general
-option(LLAMA_STATIC                     "llama: static link libraries"                  OFF)
-option(LLAMA_NATIVE                     "llama: enable -march=native flag"              OFF)
-option(LLAMA_LTO                        "llama: enable link time optimization"          OFF)
+option(LLAMA_STATIC                          "llama: static link libraries"               OFF)
+option(LLAMA_NATIVE                          "llama: enable -march=native flag"           OFF)
+option(LLAMA_LTO                             "llama: enable link time optimization"       OFF)
 
 # debug
-option(LLAMA_ALL_WARNINGS               "llama: enable all compiler warnings"           ON)
-option(LLAMA_ALL_WARNINGS_3RD_PARTY     "llama: enable all compiler warnings in 3rd party libs" OFF)
-option(LLAMA_GPROF                      "llama: enable gprof"                           OFF)
+option(LLAMA_ALL_WARNINGS                    "llama: enable all compiler warnings"        ON)
+option(LLAMA_ALL_WARNINGS_3RD_PARTY          "llama: enable all compiler warnings in 3rd party libs" OFF)
+option(LLAMA_GPROF                           "llama: enable gprof"                        OFF)
 
 # sanitizers
-option(LLAMA_SANITIZE_THREAD            "llama: enable thread sanitizer"                OFF)
-option(LLAMA_SANITIZE_ADDRESS           "llama: enable address sanitizer"               OFF)
-option(LLAMA_SANITIZE_UNDEFINED         "llama: enable undefined sanitizer"             OFF)
+option(LLAMA_SANITIZE_THREAD                 "llama: enable thread sanitizer"             OFF)
+option(LLAMA_SANITIZE_ADDRESS                "llama: enable address sanitizer"            OFF)
+option(LLAMA_SANITIZE_UNDEFINED              "llama: enable undefined sanitizer"          OFF)
 
 # instruction set specific
-option(LLAMA_AVX                        "llama: enable AVX"                             ON)
-option(LLAMA_AVX2                       "llama: enable AVX2"                            ON)
-option(LLAMA_AVX512                     "llama: enable AVX512"                          OFF)
-option(LLAMA_AVX512_VBMI                "llama: enable AVX512-VBMI"                     OFF)
-option(LLAMA_AVX512_VNNI                "llama: enable AVX512-VNNI"                     OFF)
-option(LLAMA_FMA                        "llama: enable FMA"                             ON)
+option(LLAMA_AVX                             "llama: enable AVX"                          ON)
+option(LLAMA_AVX2                            "llama: enable AVX2"                         ON)
+option(LLAMA_AVX512                          "llama: enable AVX512"                       OFF)
+option(LLAMA_AVX512_VBMI                     "llama: enable AVX512-VBMI"                  OFF)
+option(LLAMA_AVX512_VNNI                     "llama: enable AVX512-VNNI"                  OFF)
+option(LLAMA_FMA                             "llama: enable FMA"                          ON)
 # in MSVC F16C is implied with AVX2/AVX512
 if (NOT MSVC)
-    option(LLAMA_F16C                   "llama: enable F16C"                            ON)
+    option(LLAMA_F16C                        "llama: enable F16C"                         ON)
 endif()
 
 # 3rd party libs
-option(LLAMA_ACCELERATE                 "llama: enable Accelerate framework"            ON)
-option(LLAMA_BLAS                       "llama: use BLAS"                               OFF)
-option(LLAMA_BLAS_VENDOR                "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
-option(LLAMA_CUBLAS                     "llama: use cuBLAS"                             OFF)
-option(LLAMA_CLBLAST                    "llama: use CLBlast"                            OFF)
+option(LLAMA_ACCELERATE                      "llama: enable Accelerate framework"         ON)
+option(LLAMA_BLAS                            "llama: use BLAS"                            OFF)
+option(LLAMA_BLAS_VENDOR                     "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
+option(LLAMA_CUBLAS                          "llama: use cuBLAS"                          OFF)
+set(LLAMA_CUDA_DMMV_X                       "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
+set(LLAMA_CUDA_DMMV_Y                        "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
+option(LLAMA_CLBLAST                         "llama: use CLBlast"                         OFF)
 
-option(LLAMA_BUILD_TESTS                "llama: build tests"            ${LLAMA_STANDALONE})
-option(LLAMA_BUILD_EXAMPLES             "llama: build examples"         ${LLAMA_STANDALONE})
-option(LLAMA_BUILD_SERVER               "llama: build server example"   OFF)
+option(LLAMA_BUILD_TESTS                     "llama: build tests"                         ${LLAMA_STANDALONE})
+option(LLAMA_BUILD_EXAMPLES                  "llama: build examples"                      ${LLAMA_STANDALONE})
+option(LLAMA_BUILD_SERVER                    "llama: build server example"                OFF)
 
 #
 # Build info header
@@ -184,6 +186,8 @@ if (LLAMA_CUBLAS)
         set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
 
         add_compile_definitions(GGML_USE_CUBLAS)
+        add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
+        add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
 
         if (LLAMA_STATIC)
             set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
diff --git a/Makefile b/Makefile
index 08e250314..804307b53 100644
--- a/Makefile
+++ b/Makefile
@@ -133,9 +133,19 @@ ifdef LLAMA_CUBLAS
 	OBJS      += ggml-cuda.o
 	NVCC      = nvcc
 	NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
+ifdef LLAMA_CUDA_DMMV_X
+	NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
+else
+	NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
+endif # LLAMA_CUDA_DMMV_X
+ifdef LLAMA_CUDA_DMMV_Y
+	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
+else
+	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
+endif # LLAMA_CUDA_DMMV_Y
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
 	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
-endif
+endif # LLAMA_CUBLAS
 ifdef LLAMA_CLBLAST
 	CFLAGS   += -DGGML_USE_CLBLAST
 	CXXFLAGS += -DGGML_USE_CLBLAST
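
Both build systems feed the same two macros to nvcc, so the kernel geometry can now be tuned at configure time without touching the sources: for example cmake .. -DLLAMA_CUBLAS=ON -DLLAMA_CUDA_DMMV_X=64 -DLLAMA_CUDA_DMMV_Y=2 with CMake, or make LLAMA_CUBLAS=1 LLAMA_CUDA_DMMV_X=64 LLAMA_CUDA_DMMV_Y=2 with the Makefile build (the values here are arbitrary illustrations, not a recommendation made by this patch). If neither variable is given, ggml-cuda.cu falls back to the #ifndef defaults of 32 and 1 added below.
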
"1" CACHE STRING "llama: y block size for dmmv CUDA kernels") +option(LLAMA_CLBLAST "llama: use CLBlast" OFF) -option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) -option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) -option(LLAMA_BUILD_SERVER "llama: build server example" OFF) +option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_SERVER "llama: build server example" OFF) # # Build info header @@ -184,6 +186,8 @@ if (LLAMA_CUBLAS) set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h) add_compile_definitions(GGML_USE_CUBLAS) + add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) + add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y}) if (LLAMA_STATIC) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) diff --git a/Makefile b/Makefile index 08e250314..804307b53 100644 --- a/Makefile +++ b/Makefile @@ -133,9 +133,19 @@ ifdef LLAMA_CUBLAS OBJS += ggml-cuda.o NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native +ifdef LLAMA_CUDA_DMMV_X + NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) +else + NVCCFLAGS += -DGGML_CUDA_DMMV_X=32 +endif # LLAMA_CUDA_DMMV_X +ifdef LLAMA_CUDA_DMMV_Y + NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y) +else + NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1 +endif # LLAMA_CUDA_DMMV_Y ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ -endif +endif # LLAMA_CUBLAS ifdef LLAMA_CLBLAST CFLAGS += -DGGML_USE_CLBLAST CXXFLAGS += -DGGML_USE_CLBLAST diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 35d2e457c..98170a3ae 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -83,9 +83,19 @@ typedef struct { } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); +#define WARP_SIZE 32 + #define CUDA_MUL_BLOCK_SIZE 256 + #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 -#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec + +// dmmv = dequantize_mul_mat_vec +#ifndef GGML_CUDA_DMMV_X +#define GGML_CUDA_DMMV_X 32 +#endif +#ifndef GGML_CUDA_DMMV_Y +#define GGML_CUDA_DMMV_Y 1 +#endif static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -200,41 +210,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) dequantize_kernel(vx, ib, iqs, v0, v1); } -template +template static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { - const int row = blockIdx.x; + // qk = quantized weights per x block + // qr = number of quantized weights per data value in x block + const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; + const int iter_stride = 2*GGML_CUDA_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter const int y_offset = qr == 1 ? 
@@ -269,33 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
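
Every launcher now shares the same geometry: a two-dimensional block of WARP_SIZE x GGML_CUDA_DMMV_Y threads, i.e. GGML_CUDA_DMMV_Y warps per block with each warp owning one matrix row, and a grid of nrows/GGML_CUDA_DMMV_Y blocks, which is what the added nrows % GGML_CUDA_DMMV_Y == 0 assertion protects. A small illustrative helper (assuming the WARP_SIZE and GGML_CUDA_DMMV_Y macros defined in ggml-cuda.cu; not part of the patch):

    #include <cuda_runtime.h>

    // One block covers GGML_CUDA_DMMV_Y consecutive rows, one warp per row.
    // Inside the kernel the owning row is recovered as
    //   row = blockIdx.x*blockDim.y + threadIdx.y
    // e.g. with GGML_CUDA_DMMV_Y = 2, block 7 computes rows 14 and 15.
    static dim3 dmmv_block_dims() {
        return dim3(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
    }

    static int dmmv_num_blocks(const int nrows) {
        return nrows / GGML_CUDA_DMMV_Y; // requires nrows % GGML_CUDA_DMMV_Y == 0
    }
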
@@ -304,9 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
-    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 1, 1, convert_f16>
-        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
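
The f16 path reuses the same kernel with qk = 1 and qr = 1: there is no quantization block structure, y_offset collapses to 1, and "dequantization" is just widening two consecutive half values to float. A sketch of what such a converter looks like under the dequantize_kernel_t signature used above (illustrative and hypothetical, not reproduced from this patch; note it must apply iqs, since the unrolled j loop passes iqs + j/qr rather than always 0):

    #include <cuda_fp16.h>

    // f16 "dequantize" callback: load two consecutive half values and widen
    // them to float. ib points at the start of this lane's chunk and iqs
    // carries the offset added by the j loop.
    static __device__ void convert_f16_sketch(const void * vx, const int ib, const int iqs, float & v0, float & v1) {
        const half * x = (const half *) vx;
        v0 = __half2float(x[ib + iqs + 0]);
        v1 = __half2float(x[ib + iqs + 1]);
    }
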