diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 08428ea3f..79e2d313a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -14,9 +14,11 @@ // for rocblas_initialize() #include "rocblas/rocblas.h" #endif // __HIP_PLATFORM_AMD__ +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N #define CUBLAS_OP_T HIPBLAS_OP_T #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS @@ -235,8 +237,12 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } +template +using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream); +typedef to_t_cuda_t to_fp32_cuda_t; +typedef to_t_cuda_t to_fp16_cuda_t; + typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); -typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream); typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); typedef void (*cpy_kernel_t)(const char * cx, char * cdst); typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); @@ -1515,6 +1521,14 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, v.y = x[ib + iqs + 1]; } +static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){ + const float * x = (const float *) vx; + + // automatic half -> float type cast if dfloat == float + v.x = x[ib + iqs + 0]; + v.y = x[ib + iqs + 1]; +} + static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { const int ix = blockDim.x*blockIdx.x + threadIdx.x; @@ -1554,8 +1568,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } -template -static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) { +template +static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; if (i >= k) { @@ -4826,6 +4840,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c dequantize_block<1, 1, convert_f16><<>>(vx, y, k); } +static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + dequantize_block<1, 1, convert_f32><<>>(vx, y, k); +} + static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; @@ -4835,6 +4854,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa <<>>(vx, y, dst, ncols, nrows); } +static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_fp32_to_fp16_cuda; + default: + return nullptr; + } +} + static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -6016,8 +6044,6 @@ inline void ggml_cuda_op_mul_mat_cublas( GGML_ASSERT(src1_ddf_i != nullptr); GGML_ASSERT(dst_dd_i != nullptr); - const float alpha = 1.0f; - const float beta = 0.0f; const int64_t ne00 = src0->ne[0]; @@ -6026,16 +6052,6 @@ inline void ggml_cuda_op_mul_mat_cublas( const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; - float * src0_ddq_as_f32; - size_t src0_as = 0; - - if (src0->type != GGML_TYPE_F32) { - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); - src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT - to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); - } - const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; - int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -6043,16 +6059,72 @@ inline void ggml_cuda_op_mul_mat_cublas( // ldc == nrows of the matrix that cuBLAS writes into int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; - CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); - CUBLAS_CHECK( - cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, - row_diff, src1_ncols, ne10, - &alpha, src0_ddf_i, ne00, - src1_ddf_i, ne10, - &beta, dst_dd_i, ldc)); + const int compute_capability = g_compute_capabilities[id]; - if (src0_as > 0) { - ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); + if (compute_capability >= CC_TURING && src0->type == GGML_TYPE_F16 && ggml_is_contiguous(src0) && ldc == row_diff) { + // convert src1 to fp16, multiply as fp16, convert dst to fp32 + half * src1_as_f16 = nullptr; + size_t src1_as = 0; + if (src1->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + size_t ne = src1_ncols*ne10; + src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); + to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); + } + const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16; + + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as); + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); + CUBLAS_CHECK( + cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f16, src0_dd_i, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16, CUDA_R_16F, ldc, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); + + ggml_cuda_pool_free(dst_f16, dst_as); + + if (src1_as != 0) { + ggml_cuda_pool_free(src1_as_f16, src1_as); + } + } + else { + float * src0_ddq_as_f32 = nullptr; + size_t src0_as = 0; + + if (src0->type != GGML_TYPE_F32) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); + GGML_ASSERT(to_fp32_cuda != nullptr); + src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT + to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); + } + const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; + + const float alpha = 1.0f; + const float beta = 0.0f; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream)); + CUBLAS_CHECK( + cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha, src0_ddf_i, ne00, + src1_ddf_i, ne10, + &beta, dst_dd_i, ldc)); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_ddq_as_f32, src0_as); + } } (void) dst;