diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5053757e6..96976f248 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -150,8 +150,8 @@ #define CUDA_USE_TENSOR_CORES #endif -// max batch size to use MMQ kernels when tensor cores are available -#define MMQ_MAX_BATCH_SIZE 32 +#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels +#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -5310,51 +5310,59 @@ template static __global__ void #endif // __CUDA_ARCH__ >= CC_VOLTA } -#define MMVQ_NWARPS_NVIDIA 4 -#define MMVQ_NWARPS_AMD_RDNA2 1 -#define MMVQ_NWARPS_AMD_OLD 4 - -template +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants +// tell the compiler to use as many registers as it wants, see nwarps definition below +__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { - const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par; +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) + constexpr int nwarps = 1; + constexpr int rows_per_cuda_block = 1; +#else + constexpr int nwarps = ncols_y <= 4 ? 4 : 2; + constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - const int row = blockIdx.x; - - const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_col_y = nrows_y / QK8_1; - const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + const int row0 = rows_per_cuda_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi; // partial sum for each thread - float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f}; + float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; const block_q_t * x = (const block_q_t *) vx; const block_q8_1 * y = (const block_q8_1 *) vy; - for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) { - const int ibx = row*blocks_per_row_x + i; // x block index + for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx - const int iby = i * (qk/QK8_1); // y block index that aligns with ibx - - const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int + // x block quant index when casting the quants to int + const int kqs = vdr * (tid % (qi/vdr)); #pragma unroll for (int j = 0; j < ncols_y; ++j) { - tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs); +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp[j][i] += vec_dot_q_cuda( + &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs); + } } } - __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE]; + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE]; if (threadIdx.y > 0) { #pragma unroll for (int j = 0; j < ncols_y; ++j) { - tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j]; +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; + } } } __syncthreads(); @@ -5366,13 +5374,16 @@ static __global__ void mul_mat_vec_q( #pragma unroll for (int j = 0; j < ncols_y; ++j) { #pragma unroll - for (int i = 0; i < nwarps-1; ++i) { - tmp[j] += tmp_shared[i][j][threadIdx.x]; + for (int i = 0; i < rows_per_cuda_block; ++i) { +#pragma unroll + for (int l = 0; l < nwarps-1; ++l) { + tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; + } + tmp[j][i] = warp_reduce_sum(tmp[j][i]); } - tmp[j] = warp_reduce_sum(tmp[j]); - if (threadIdx.x == 0) { - dst[j*nrows_dst + row] = tmp[j]; + if (threadIdx.x < rows_per_cuda_block) { + dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; } } } @@ -6851,65 +6862,75 @@ static void mul_mat_vec_q_cuda( const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { GGML_ASSERT(ncols_x % qk == 0); - GGML_ASSERT(ncols_y <= 4); + GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); int id; CUDA_CHECK(cudaGetDevice(&id)); - int nwarps; - if (g_device_caps[id].cc >= CC_OFFSET_AMD) { - nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD; - } else { - nwarps = MMVQ_NWARPS_NVIDIA; - } + int64_t nwarps = 1; + int64_t rows_per_cuda_block = 1; - const dim3 block_nums(nrows_x, 1, 1); + if (g_device_caps[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + switch(ncols_y) { + case 1: + nwarps = 4; + rows_per_cuda_block = 1; + break; + case 2: + case 3: + case 4: + nwarps = 4; + rows_per_cuda_block = 2; + break; + case 5: + case 6: + case 7: + case 8: + nwarps = 2; + rows_per_cuda_block = 2; + break; + default: + GGML_ASSERT(false); + break; + } + } + const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; + const dim3 block_nums(nblocks, 1, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); - switch (nwarps) { - case 1: switch(ncols_y) { - case 1: - mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 2: - mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 3: - mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 4: - mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - default: - GGML_ASSERT(false); - break; - } break; - case 4: switch(ncols_y) { - case 1: - mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 2: - mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 3: - mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 4: - mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - default: - GGML_ASSERT(false); - break; - } break; - + switch (ncols_y) { + case 1: + mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 2: + mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 3: + mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 4: + mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 5: + mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 6: + mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 7: + mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; + case 8: + mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + break; default: GGML_ASSERT(false); break; @@ -9735,7 +9756,7 @@ static __global__ void k_compute_batched_ptrs( ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; } -static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_mul_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); @@ -9893,39 +9914,69 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 int64_t min_compute_capability = INT_MAX; + bool any_pascal_with_slow_fp16 = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; auto & tensor_split = buft_ctx->tensor_split; for (int id = 0; id < g_device_count; ++id) { - if (min_compute_capability > g_device_caps[id].cc && tensor_split[id] < (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) { + // skip devices that are not going to do any work: + if (tensor_split[id] >= (id + 1 < g_device_count ? tensor_split[id + 1] : 1.0f)) { + continue; + } + + if (min_compute_capability > g_device_caps[id].cc) { min_compute_capability = g_device_caps[id].cc; } + if (g_device_caps[id].cc == 610) { + any_pascal_with_slow_fp16 = true; + } } } else { - min_compute_capability = g_device_caps[g_main_device].cc; + min_compute_capability = g_device_caps[g_main_device].cc; + any_pascal_with_slow_fp16 = g_device_caps[g_main_device].cc == 610; } + // check data types and tensor shapes for custom matrix multiplication kernels: + bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; + + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + + bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) const bool fp16_performance_good = min_compute_capability >= CC_RDNA1; - bool use_mul_mat_q = ggml_is_quantized(src0->type); + #ifdef CUDA_USE_TENSOR_CORES use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3; #endif // CUDA_USE_TENSOR_CORES #else - const bool fp16_performance_good = min_compute_capability >= CC_VOLTA; - bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); + // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0) + const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16; + + // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1 + use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A; + use_mul_mat_q = use_mul_mat_q && min_compute_capability >= MIN_CC_DP4A; + #ifdef CUDA_USE_TENSOR_CORES // when tensor cores are available, use them for large batch size // ref: https://github.com/ggerganov/llama.cpp/pull/3776 - use_mul_mat_q = use_mul_mat_q && !(fp16_performance_good && src1->ne[1] > MMQ_MAX_BATCH_SIZE); + use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE); #endif // CUDA_USE_TENSOR_CORES #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - use_mul_mat_q = use_mul_mat_q && ggml_cuda_supports_mmq(src0->type); + // if mmvq is available it's a better choice than dmmv: +#ifndef GGML_CUDA_FORCE_DMMV + use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; +#endif // GGML_CUDA_FORCE_DMMV // debug helpers //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); @@ -9943,33 +9994,15 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 ggml_cuda_mul_mat_vec_nc(src0, src1, dst); } else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch - ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); - } else if (src0->type == GGML_TYPE_F32) { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); - } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->type == GGML_TYPE_F32) { -#ifdef GGML_CUDA_FORCE_DMMV - const bool use_mul_mat_vec_q = false; -#else - const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); -#endif // GGML_CUDA_FORCE_DMMV - - if (use_mul_mat_vec_q) { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); - } else { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); - } - } else { - if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); - } else if (use_mul_mat_q) { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true); - } else { - ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); - } - } + ggml_cuda_mul_mat_batched_cublas(src0, src1, dst); + } else if (use_dequantize_mul_mat_vec) { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); + } else if (use_mul_mat_vec_q) { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); + } else if (use_mul_mat_q) { + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true); } else { - GGML_ASSERT(false); + ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } }