From 6b1cf541971732af2da15b9a2f645dabd1c24b4b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 7 Dec 2023 13:32:54 +0200 Subject: [PATCH] sync : ggml (part 2, CUDA) --- ggml-cuda.cu | 1404 ++++++++++++++++++++++++++++++++++++-------------- ggml-cuda.h | 10 +- 2 files changed, 1024 insertions(+), 390 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1200d1c88..51ff390e3 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,13 +1,12 @@ #include -#include #include #include +#include #include #include #include #include #include -#include #if defined(GGML_USE_HIPBLAS) #include @@ -70,6 +69,7 @@ #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice #define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamFireAndForget hipStreamFireAndForget #define cudaStreamNonBlocking hipStreamNonBlocking #define cudaStreamSynchronize hipStreamSynchronize #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) @@ -191,7 +191,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ cudaGetErrorString(err_)); \ fprintf(stderr, "current device: %d\n", id); \ - exit(1); \ + GGML_ASSERT(!"CUDA error"); \ } \ } while (0) @@ -205,7 +205,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \ err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \ fprintf(stderr, "current device: %d\n", id); \ - exit(1); \ + GGML_ASSERT(!"cuBLAS error"); \ } \ } while (0) #else @@ -217,7 +217,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); cudaGetDevice(&id); \ fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ fprintf(stderr, "current device: %d\n", id); \ - exit(1); \ + GGML_ASSERT(!"cuBLAS error"); \ } \ } while (0) #endif // CUDART_VERSION >= 11 @@ -434,8 +434,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define WARP_SIZE 32 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses -#define CUDA_ADD_BLOCK_SIZE 256 -#define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 @@ -528,40 +526,87 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } -static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= kx) { - return; - } - dst[i] = x[i] + y[i%ky]; +static __device__ __forceinline__ float op_repeat(const float a, const float b) { + return b; } -static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = __hadd(x[i], __float2half(y[i])); +static __device__ __forceinline__ float op_add(const float a, const float b) { + return a + b; } -static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= k) { - return; - } - dst[i] = __half2float(x[i]) + y[i]; +static __device__ __forceinline__ float op_mul(const float a, const float b) { + return a * b; } -static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +static __device__ __forceinline__ float op_div(const float a, const float b) { + return a / b; +} - if (i >= kx) { +template +static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s10,*/ int s11, int s12, int s13) { + const int i0s = blockDim.x*blockIdx.x + threadIdx.x; + const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); + const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3; + const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3; + + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } - dst[i] = x[i] * y[i%ky]; + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i_src0; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) { + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + } +} + +template +static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s10,*/ int s11, int s12, int s13) { + + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + const int i3 = i/(ne2*ne1*ne0); + const int i2 = (i/(ne1*ne0)) % ne2; + const int i1 = (i/ne0) % ne1; + const int i0 = i % ne0; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i_src0; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } static __global__ void gelu_f32(const float * x, float * dst, const int k) { @@ -605,12 +650,10 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) { } template -static __global__ void norm_f32(const float * x, float * dst, const int ncols) { +static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - const float eps = 1e-5f; - float2 mean_var = make_float2(0.f, 0.f); for (int col = tid; col < ncols; col += block_size) { @@ -4730,8 +4773,8 @@ static __global__ void rope( template static __global__ void rope_neox( - const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims + const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, + float ext_factor, float attn_factor, rope_corr_dims corr_dims ) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -4740,25 +4783,23 @@ static __global__ void rope_neox( } const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int ib = col / n_dims; - const int ic = col % n_dims; - - const int i = row*ncols + ib*n_dims + ic/2; + const int i = row*ncols + col/2; const int i2 = row/p_delta_rows; - float cur_rot = inv_ndims * ic - ib; + // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero + const float cur_rot = -float(col)/ncols; const int p = has_pos ? pos[i2] : 0; - const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); + const float theta_base = p*powf(freq_base, cur_rot); float cos_theta, sin_theta; rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; - const float x1 = x[i + n_dims/2]; + const float x1 = x[i + ncols/2]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; } static __global__ void rope_glm_f32( @@ -4824,6 +4865,65 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols, dst[i] = col * m_k + x[i]; } +static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) { + const int row = blockIdx.y; + const int col = threadIdx.x; + + float sum = 0.0f; + for (int i = col; i < ncols; i += blockDim.x) { + sum += x[row * ncols + i]; + } + + sum = warp_reduce_sum(sum); + + if (col == 0) { + dst[row] = sum; + } +} + +template +static inline __device__ void swap(T & a, T & b) { + T tmp = a; + a = b; + b = tmp; +} + +template +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.y; + + if (col >= ncols) return; + + const float * x_row = x + row * ncols; + int * dst_row = dst + row * ncols; + + // initialize indices + if (col < ncols) { + dst_row[col] = col; + } + __syncthreads(); + + for (int k = 2; k <= ncols; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + swap(dst_row[col], dst_row[ixj]); + } + } else { + if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + swap(dst_row[col], dst_row[ixj]); + } + } + } + __syncthreads(); + } + } +} + static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { const int col = blockDim.y*blockIdx.y + threadIdx.y; const int row = blockDim.x*blockIdx.x + threadIdx.x; @@ -4833,8 +4933,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int } const int i = row*ncols + col; - // dst[i] = col > n_past + row ? -INFINITY : x[i]; - dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU + //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i]; + //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU + dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; } static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) { @@ -4956,25 +5057,119 @@ static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const k_get_rows<<>>(x, y, dst, ncols); } -static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { - const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f32<<>>(x, y, dst, kx, ky); -} +template +struct bin_bcast_cuda { + template + void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, + const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, + cudaStream_t stream) { -static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f16_f32_f16<<>>(x, y, dst, k); -} + GGML_TENSOR_BINARY_OP_LOCALS -static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f16_f32_f32<<>>(x, y, dst, k); -} -static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { - const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; - mul_f32<<>>(x, y, dst, kx, ky); -} + int nr0 = ne10/ne0; + int nr1 = ne11/ne1; + int nr2 = ne12/ne2; + int nr3 = ne13/ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + // collapse dimensions until first broadcast dimension + int64_t cne0[] = {ne0, ne1, ne2, ne3}; + int64_t cne1[] = {ne10, ne11, ne12, ne13}; + size_t cnb0[] = {nb0, nb1, nb2, nb3}; + size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne0); + collapse(cne1); + } + } + { + int64_t ne0 = cne0[0]; + int64_t ne1 = cne0[1]; + int64_t ne2 = cne0[2]; + int64_t ne3 = cne0[3]; + + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; + + //size_t nb0 = cnb0[0]; + size_t nb1 = cnb0[1]; + size_t nb2 = cnb0[2]; + size_t nb3 = cnb0[3]; + + //size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + //size_t s0 = nb0 / sizeof(src1_t); + size_t s1 = nb1 / sizeof(src1_t); + size_t s2 = nb2 / sizeof(src1_t); + size_t s3 = nb3 / sizeof(src1_t); + + //size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + + const int block_size = 128; + + int64_t hne0 = std::max(ne0/2LL, 1LL); + + dim3 block_dims; + block_dims.x = std::min(hne0, block_size); + block_dims.y = std::min(ne1, block_size / block_dims.x); + block_dims.z = std::min(std::min(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U); + + dim3 block_nums( + (hne0 + block_dims.x - 1) / block_dims.x, + (ne1 + block_dims.y - 1) / block_dims.y, + (ne2*ne3 + block_dims.z - 1) / block_dims.z + ); + + if (block_nums.z > 65535) { + // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; + k_bin_bcast_unravel<<>>( + src0_dd, src1_dd, dst_dd, + ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s10, */ s11, s12, s13); + } else { + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, + ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s10, */ s11, s12, s13); + } + } + } +}; static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; @@ -4996,14 +5191,14 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t sqr_f32<<>>(x, dst, k); } -static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - norm_f32<<>>(x, dst, ncols); + norm_f32<<>>(x, dst, ncols, eps); } else { const dim3 block_dims(1024, 1, 1); - norm_f32<1024><<>>(x, dst, ncols); + norm_f32<1024><<>>(x, dst, ncols, eps); } } @@ -5025,38 +5220,14 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con quantize_q8_1<<>>(x, vy, kx, kx_padded); } -template -static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +template +static __host__ __device__ void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); + dequantize_block<<>>(vx, y, k); } template -static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static __host__ __device__ void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); @@ -5066,7 +5237,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static __host__ __device__ void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); @@ -5076,13 +5247,13 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static __host__ __device__ void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q4_K<<>>(vx, y); } template -static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static __host__ __device__ void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); @@ -5092,7 +5263,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static __host__ __device__ void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); @@ -5101,6 +5272,64 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu #endif } +static to_fp16_cuda_t __host__ __device__ ggml_get_to_fp16_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_F32: + return dequantize_block_cuda<1, 1, convert_f32>; + default: + return nullptr; + } +} + +static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_F16: + return dequantize_block_cuda<1, 1, convert_f16>; + default: + return nullptr; + } +} + static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; @@ -5189,6 +5418,15 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); } +static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + dequantize_mul_mat_vec<1, 1, convert_f16> + <<>>(vx, y, dst, ncols, nrows); +} + static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; @@ -5279,83 +5517,6 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<1, 1, convert_f16><<>>(vx, y, k); -} - -static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - dequantize_block<1, 1, convert_f32><<>>(vx, y, k); -} - -static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec<1, 1, convert_f16> - <<>>(vx, y, dst, ncols, nrows); -} - -static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: - return dequantize_row_q4_0_cuda; - case GGML_TYPE_Q4_1: - return dequantize_row_q4_1_cuda; - case GGML_TYPE_Q5_0: - return dequantize_row_q5_0_cuda; - case GGML_TYPE_Q5_1: - return dequantize_row_q5_1_cuda; - case GGML_TYPE_Q8_0: - return dequantize_row_q8_0_cuda; - case GGML_TYPE_Q2_K: - return dequantize_row_q2_K_cuda; - case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_cuda; - case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_cuda; - case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_cuda; - case GGML_TYPE_Q6_K: - return dequantize_row_q6_K_cuda; - case GGML_TYPE_F32: - return convert_fp32_to_fp16_cuda; - default: - return nullptr; - } -} - -static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: - return dequantize_row_q4_0_cuda; - case GGML_TYPE_Q4_1: - return dequantize_row_q4_1_cuda; - case GGML_TYPE_Q5_0: - return dequantize_row_q5_0_cuda; - case GGML_TYPE_Q5_1: - return dequantize_row_q5_1_cuda; - case GGML_TYPE_Q8_0: - return dequantize_row_q8_0_cuda; - case GGML_TYPE_Q2_K: - return dequantize_row_q2_K_cuda; - case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_cuda; - case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_cuda; - case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_cuda; - case GGML_TYPE_Q6_K: - return dequantize_row_q6_K_cuda; - case GGML_TYPE_F16: - return convert_fp16_to_fp32_cuda; - default: - return nullptr; - } -} - static void ggml_mul_mat_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5923,26 +6084,20 @@ static void rope_cuda( template static void rope_neox_cuda( - const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, + const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream ) { GGML_ASSERT(ncols % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nrows, num_blocks_x, 1); - - const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float inv_ndims = -1.0f / n_dims; - if (pos == nullptr) { rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims + x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims ); } else { rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims + x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims ); } } @@ -5967,6 +6122,27 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const alibi_f32<<>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1); } +static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(1, nrows, 1); + k_sum_rows_f32<<>>(x, dst, ncols); +} + +static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { + // bitonic sort requires ncols to be power of 2 + GGML_ASSERT((ncols & (ncols - 1)) == 0); + + const dim3 block_dims(ncols, 1, 1); + const dim3 block_nums(1, nrows, 1); + if (order == GGML_SORT_ASC) { + k_argsort_f32_i32<<>>(x, dst, ncols); + } else if (order == GGML_SORT_DESC) { + k_argsort_f32_i32<<>>(x, dst, ncols); + } else { + GGML_ASSERT(false); + } +} + static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1); const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; @@ -6059,7 +6235,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } #ifdef DEBUG_CUDA_MALLOC - fprintf(stderr, "%s: %d buffers, max_size = %u MiB, tot_size = %u MiB, requested %u MiB\n", __func__, nnz, + fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz, (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024)); #endif void * ptr; @@ -6197,7 +6373,7 @@ void * ggml_cuda_host_malloc(size_t size) { // The allocation error can be bypassed. A null ptr will assigned out of this function. // This can fixed the OOM error in WSL. cudaGetLastError(); - fprintf(stderr, "WARNING: failed to allocate %.2f MiB of pinned memory: %s\n", + fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size/1024.0/1024.0, cudaGetErrorString(err)); return nullptr; } @@ -6237,81 +6413,23 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( const enum ggml_type type = src->type; const int64_t ts = ggml_type_size(type); const int64_t bs = ggml_blck_size(type); - const int64_t i1_diff = i1_high - i1_low; + int64_t i1_diff = i1_high - i1_low; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; - if (nb0 == ts && nb1 == ts*(ne0/bs)) { + if (nb0 == ts && nb1 == ts*ne0/bs) { return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream); - } - if (nb0 == ts) { - return cudaMemcpy2DAsync(dst_ptr, ts*(ne0/bs), x, nb1, ts*(ne0/bs), i1_diff, kind, stream); - } - GGML_ASSERT(bs == 1 && "TODO: implement bs != 1"); - for (int64_t i1 = 0; i1 < i1_diff; i1++) { - const void * rx = (const void *) ((const char *) x + i1*nb1); - void * rd = (void *) (dst_ptr + i1*ts*ne0); - // pretend the row is a matrix with cols=1 - cudaError_t r = cudaMemcpy2DAsync(rd, ts, rx, nb0, ts, ne0, kind, stream); - if (r != cudaSuccess) { return r; } - } - return cudaSuccess; -} - -static void ggml_cuda_op_repeat( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { - // guaranteed to be an integer due to the check in ggml_can_repeat - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = (int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - CUDA_CHECK(cudaMemcpyAsync( - (char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0, - (const char *) src0_d + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01, - ne00*nb0, cudaMemcpyDeviceToDevice, stream)); - } - } - } - } - } + } else if (nb0 == ts) { + return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream); + } else { + for (int64_t i1 = 0; i1 < i1_diff; i1++) { + const void * rx = (const void *) ((const char *) x + i1*nb1); + void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); + // pretend the row is a matrix with cols=1 + cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream); + if (r != cudaSuccess) return r; } + return cudaSuccess; } - - (void) src1; - (void) src1_d; } static void ggml_cuda_op_get_rows( @@ -6358,44 +6476,55 @@ static void ggml_cuda_op_get_rows( } } -inline void ggml_cuda_op_add( +template +inline void ggml_cuda_op_bin_bcast( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { GGML_ASSERT(src1->type == GGML_TYPE_F32); - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream); + op()(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream); + op()(src0, src1, dst, (const half *) src0_dd, src1_dd, (half *) dst_dd, main_stream); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream); + op()(src0, src1, dst, (const half *) src0_dd, src1_dd, dst_dd, main_stream); } else { - fprintf(stderr, "src0->type: %d dst->type: %d\n", src0->type, dst->type); + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); GGML_ASSERT(false); } +} + +static void ggml_cuda_op_repeat( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) { + + ggml_cuda_op_bin_bcast>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream); (void) src1; - (void) dst; + (void) src1_d; +} + +inline void ggml_cuda_op_add( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } inline void ggml_cuda_op_mul( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; +inline void ggml_cuda_op_div( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream); - - (void) dst; + ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } inline void ggml_cuda_op_gelu( @@ -6464,7 +6593,10 @@ inline void ggml_cuda_op_norm( const int64_t ne00 = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); - norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream); (void) src1; (void) dst; @@ -6903,14 +7035,15 @@ inline void ggml_cuda_op_rope( GGML_ASSERT(false); rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream); } else if (is_neox) { + GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet"); if (src0->type == GGML_TYPE_F32) { rope_neox_cuda( - (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, main_stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_cuda( - (const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, + (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, attn_factor, corr_dims, main_stream ); } else { @@ -7007,6 +7140,42 @@ inline void ggml_cuda_op_im2col( (void) src0_dd; } +inline void ggml_cuda_op_sum_rows( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + sum_rows_f32_cuda(src0_dd, dst_dd, ncols, nrows, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_argsort( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + argsort_f32_i32_cuda(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + inline void ggml_cuda_op_diag_mask_inf( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -7215,7 +7384,7 @@ static void ggml_cuda_op_mul_mat( const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; const int64_t ne03 = src0->ne[3]; - // const int64_t nrows0 = ggml_nrows(src0); + const int64_t nrows0 = ggml_nrows(src0); const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; @@ -7523,6 +7692,10 @@ static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, gg ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul); } +static void ggml_cuda_div(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_div); +} + static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu); } @@ -7548,7 +7721,7 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src } bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (!g_cublas_loaded) { return false; } + if (!g_cublas_loaded) return false; const int64_t ne10 = src1->ne[0]; @@ -7626,7 +7799,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); } -__global__ static void k_compute_batched_ptrs( +static __global__ void k_compute_batched_ptrs( const half * src0_as_f16, const half * src1_as_f16, half * dst_f16, const void ** ptrs_src, void ** ptrs_dst, int ne12, int ne13, @@ -7682,9 +7855,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUDA_CHECK(ggml_cuda_set_device(g_main_device)); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - int id; - CUDA_CHECK(cudaGetDevice(&id)); - CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream)); ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; void * src0_ddq = src0_extra->data_device[g_main_device]; @@ -7741,7 +7912,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx CUBLAS_CHECK( - cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB @@ -7775,7 +7946,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUDA_CHECK(cudaGetLastError()); CUBLAS_CHECK( - cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half), (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float), @@ -7845,11 +8016,10 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 #ifdef GGML_CUDA_FORCE_DMMV const bool use_mul_mat_vec_q = false; #else - const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1; + const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type); #endif // GGML_CUDA_FORCE_DMMV if (use_mul_mat_vec_q) { - // NOTE: this kernel does not support ggml_nrows(src1) > 1 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); } else { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); @@ -7874,6 +8044,219 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } +#if 0 +template +static __global__ void k_compute_batched_ptrs_id( + const void ** ptrs_src, void ** ptrs_dst, + int ne12, int ne13, + int ne23, + int nb02, int nb03, + int nb12, int nb13, + int nb2, int nb3, + int r2, int r3, + ggml_type src0_type, half * src0_as_f16, int64_t src0_ne, + const half * src1_f16, half * dst_f16, + const int32_t * ids, const int id, + Srcs... src0s) { + + int i = ids[id]; + + half * src0_f16; + const void * srcs_ar[] = { (const half *) src0s... }; + if (src0_type == GGML_TYPE_F16) { + src0_f16 = (half *) srcs_ar[i]; + } else { + src0_f16 = src0_as_f16; + if (threadIdx.x == 0 && threadIdx.y == 0) { + const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type); + to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget); + } + } + + int i13 = blockIdx.x * blockDim.x + threadIdx.x; + int i12 = blockIdx.y * blockDim.y + threadIdx.y; + + if (i13 >= ne13 || i12 >= ne12) { + return; + } + + int i03 = i13 / r3; + int i02 = i12 / r2; + + ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03; + ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2; + ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2; +} + +static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) { + const struct ggml_tensor * ids = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src00 = dst->src[2]; + + const int id = dst->op_params[0]; + + GGML_ASSERT(!ggml_is_transposed(src00)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00); + const int64_t ne01 = src00->ne[1]; + const int64_t ne02 = src00->ne[2]; + const int64_t ne03 = src00->ne[3]; + + //const int64_t nb01 = src00->nb[1]; + const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02); + const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + //const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); + const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); + + const int64_t ne1 = ggml_nelements(src1); + const int64_t ne = ggml_nelements(dst); + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream)); + + //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + //void * src0_ddq = src0_extra->data_device[g_main_device]; + //half * src0_as_f16 = (half *) src0_ddq; + + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + // convert src1 to fp16 + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + + size_t src1_as = 0; + half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); + to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); + + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + // use cublasGemmBatchedEx + const int ne23 = ne12*ne13; + + const void ** ptrs_src = nullptr; + void ** ptrs_dst = nullptr; + + size_t ptrs_src_s = 0; + size_t ptrs_dst_s = 0; + + ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s); + ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s); + + int64_t src0_ne = ggml_nelements(src00); + half * src0_as_f16 = nullptr; + size_t src0_as = 0; + if (src00->type != GGML_TYPE_F16) { + src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as); + } + + static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6"); + dim3 block_dims(ne13, ne12); + k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>( + ptrs_src, ptrs_dst, + ne12, ne13, + ne23, + ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half), + nb12, nb13, + dst->nb[2], dst->nb[3], + r2, r3, + src00->type, src0_as_f16, src0_ne, + src1_as_f16, dst_f16, + (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id, + dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr, + dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr, + dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr, + dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr + ); + CUDA_CHECK(cudaGetLastError()); + + CUBLAS_CHECK( + cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00, + (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10, + &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01, + ne23, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_as_f16, src0_as); + } + if (ptrs_src_s != 0) { + ggml_cuda_pool_free(ptrs_src, ptrs_src_s); + } + if (ptrs_dst_s != 0) { + ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s); + } + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); + + ggml_cuda_pool_free(src1_as_f16, src1_as); + ggml_cuda_pool_free(dst_f16, dst_as); +} +#endif + +static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) { +#if 0 +//#ifdef CUDA_USE_TENSOR_CORES +// const bool use_tensor_cores = true; +//#else +// const bool use_tensor_cores = false; +//#endif + + ggml_cuda_mul_mat_id_cublas(dst); + + // TODO: mmq/mmv support +#else + const struct ggml_tensor * ids = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const int id = dst->op_params[0]; + + int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device]; + + int32_t a_id; + CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); + CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); + + GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]); + const struct ggml_tensor * src0 = dst->src[a_id + 2]; + + ggml_cuda_mul_mat(src0, src1, dst); +#endif + + (void) _src0; + (void) _src1; +} + static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } @@ -7965,6 +8348,16 @@ static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col); } +static void ggml_cuda_sum_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sum_rows); +} + +static void ggml_cuda_argsort(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_argsort); +} + static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { (void) src0; (void) src1; @@ -8220,8 +8613,9 @@ void ggml_cuda_set_main_device(const int main_device) { main_device, g_device_count, g_main_device); return; } - g_main_device = main_device; - if (g_device_count > 1) { + + if (g_main_device != main_device && g_device_count > 1) { + g_main_device = main_device; cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, g_main_device)); fprintf(stderr, "%s: using device %d (%s) as main device\n", __func__, g_main_device, prop.name); @@ -8247,7 +8641,7 @@ void ggml_cuda_free_scratch() { } bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { - if (!g_cublas_loaded) { return false; } + if (!g_cublas_loaded) return false; ggml_cuda_func_t func; const bool any_on_device = tensor->backend == GGML_BACKEND_GPU @@ -8283,6 +8677,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_MUL: func = ggml_cuda_mul; break; + case GGML_OP_DIV: + func = ggml_cuda_div; + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_GELU: @@ -8296,7 +8693,8 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ break; default: return false; - } break; + } + break; case GGML_OP_NORM: func = ggml_cuda_norm; break; @@ -8309,6 +8707,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ } func = ggml_cuda_mul_mat; break; + case GGML_OP_MUL_MAT_ID: + if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) { + return false; + } + func = ggml_cuda_mul_mat_id; + break; case GGML_OP_SCALE: func = ggml_cuda_scale; break; @@ -8348,6 +8752,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_IM2COL: func = ggml_cuda_im2col; break; + case GGML_OP_SUM_ROWS: + func = ggml_cuda_sum_rows; + break; + case GGML_OP_ARGSORT: + func = ggml_cuda_argsort; + break; default: return false; } @@ -8364,7 +8774,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ int ggml_cuda_get_device_count() { int device_count; - CUDA_CHECK(cudaGetDeviceCount(&device_count)); + if (cudaGetDeviceCount(&device_count) != cudaSuccess) { + return 0; + } return device_count; } @@ -8380,27 +8792,16 @@ void ggml_cuda_get_device_description(int device, char * description, size_t des #define UNUSED GGML_UNUSED -struct ggml_backend_context_cuda { -}; - -static const char * ggml_backend_cuda_name(ggml_backend_t backend) { - return GGML_CUDA_NAME; - - UNUSED(backend); -} - -static void ggml_backend_cuda_free(ggml_backend_t backend) { - ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; - delete cuda_ctx; - delete backend; -} +// cuda buffer struct ggml_backend_buffer_context_cuda { - void * device; - + int device; + void * dev_ptr = nullptr; ggml_tensor_extra_gpu * temp_tensor_extras = nullptr; size_t temp_tensor_extra_index = 0; + ggml_backend_buffer_context_cuda(int device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {} + ~ggml_backend_buffer_context_cuda() { delete[] temp_tensor_extras; } @@ -8421,16 +8822,103 @@ struct ggml_backend_buffer_context_cuda { static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; - CUDA_CHECK(cudaFree(ctx->device)); + CUDA_CHECK(cudaFree(ctx->dev_ptr)); delete ctx; } static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) { ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; - return ctx->device; + return ctx->dev_ptr; } -static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + + if (tensor->view_src != NULL && tensor->view_offs == 0) { + assert(tensor->view_src->buffer->buft == buffer->buft); // TODO + tensor->backend = tensor->view_src->backend; + tensor->extra = tensor->view_src->extra; + return; + } + + ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra(); + + extra->data_device[ctx->device] = tensor->data; + + tensor->backend = GGML_BACKEND_GPU; + tensor->extra = extra; + + if (ggml_is_quantized(tensor->type)) { + // initialize padding to 0 to avoid possible NaN values + int64_t row_low = 0; + int64_t row_high = ggml_nrows(tensor); + int64_t nrows_split = row_high - row_low; + + size_t original_size = ggml_nbytes_split(tensor, nrows_split); + size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); + + if (padded_size > original_size && tensor->view_src == nullptr) { + CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[ctx->device][0])); + } + } + + UNUSED(buffer); +} + +static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + + CUDA_CHECK(cudaMemcpy((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice)); + + UNUSED(buffer); +} + +static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); + + CUDA_CHECK(cudaMemcpy(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost)); + + UNUSED(buffer); +} + +static struct ggml_backend_buffer_i cuda_backend_buffer_interface = { + /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer, + /* .get_base = */ ggml_backend_cuda_buffer_get_base, + /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor, + /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, + /* .cpy_tensor_from = */ NULL, + /* .cpy_tensor_to = */ NULL, +}; + +// cuda buffer type + +static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + int device = (int) (intptr_t) buft->context; + + ggml_cuda_set_device(device); + + size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0 + + void * dev_ptr; + CUDA_CHECK(cudaMalloc(&dev_ptr, size)); + + ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda(device, dev_ptr); + + return ggml_backend_buffer_init(buft, cuda_backend_buffer_interface, ctx, size); +} + +static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 128; + + UNUSED(buft); +} + +static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, ggml_tensor * tensor) { int64_t row_low = 0; int64_t row_high = ggml_nrows(tensor); int64_t nrows_split = row_high - row_low; @@ -8448,91 +8936,127 @@ static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buff return size; - UNUSED(buffer); + UNUSED(buft); } -static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; +static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return ggml_backend_is_cuda(backend); - if (tensor->view_src != NULL && tensor->view_offs == 0) { - assert(tensor->view_src->buffer->backend == buffer->backend); - tensor->backend = tensor->view_src->backend; - tensor->extra = tensor->view_src->extra; - return; - } - - ggml_tensor_extra_gpu * extra = ctx->ggml_cuda_alloc_temp_tensor_extra(); - - extra->data_device[g_main_device] = tensor->data; - - tensor->backend = GGML_BACKEND_GPU; - tensor->extra = extra; - - if (ggml_is_quantized(tensor->type)) { - // initialize padding to 0 to avoid possible NaN values - int64_t row_low = 0; - int64_t row_high = ggml_nrows(tensor); - int64_t nrows_split = row_high - row_low; - - size_t original_size = ggml_nbytes_split(tensor, nrows_split); - size_t padded_size = ggml_backend_cuda_buffer_get_alloc_size(tensor->buffer, tensor); - - if (padded_size > original_size && tensor->view_src == nullptr) { - CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + original_size, 0, padded_size - original_size, g_cudaStreams[g_main_device][0])); - } - } - - UNUSED(buffer); + UNUSED(buft); } -static struct ggml_backend_buffer_i cuda_backend_buffer_interface = { - /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer, - /* .get_base = */ ggml_backend_cuda_buffer_get_base, - /* .get_alloc_size = */ ggml_backend_cuda_buffer_get_alloc_size, - /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor, - /* .free_tensor = */ NULL, +static ggml_backend_buffer_type_i cuda_backend_buffer_type_interface = { + /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment, + /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, + /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend, }; -static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) { - ggml_cuda_set_device(g_main_device); +ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda[GGML_CUDA_MAX_DEVICES]; + static bool ggml_backend_buffer_type_cuda_initialized = false; + if (!ggml_backend_buffer_type_cuda_initialized) { + for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) { + ggml_backend_buffer_type_cuda[i] = { + /* .iface = */ cuda_backend_buffer_type_interface, + /* .context = */ (ggml_backend_buffer_type_context_t) (intptr_t) i, + }; + } + ggml_backend_buffer_type_cuda_initialized = true; + } - ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda; - - size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0 - - ggml_cuda_set_device(g_main_device); - CUDA_CHECK(cudaMalloc(&ctx->device, size)); - - return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size); + return &ggml_backend_buffer_type_cuda[device]; } -static size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) { - return 128; +// host buffer type + +static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; + CUDA_CHECK(cudaFreeHost(ctx->dev_ptr)); + delete ctx; +} + +static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * ptr; + CUDA_CHECK(cudaMallocHost(&ptr, size)); + + // FIXME: this is a hack to avoid having to implement a new buffer type + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; + + return buffer; + + UNUSED(buft); +} + +struct ggml_backend_buffer_type_i cuda_backend_host_buffer_type_interface = { + /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend, +}; + +ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_cuda_host = { + /* .iface = */ cuda_backend_host_buffer_type_interface, + /* .context = */ nullptr, + }; + + return &ggml_backend_buffer_type_cuda_host; +} + +// backend + +struct ggml_backend_context_cuda { + int device; +}; + +static const char * ggml_backend_cuda_name(ggml_backend_t backend) { + return GGML_CUDA_NAME; + UNUSED(backend); } +static void ggml_backend_cuda_free(ggml_backend_t backend) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + delete cuda_ctx; + delete backend; +} + +static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer_type(ggml_backend_t backend) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + return ggml_backend_cuda_buffer_type(cuda_ctx->device); +} + static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[g_main_device][0])); - - UNUSED(backend); + CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, g_cudaStreams[cuda_ctx->device][0])); } static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + GGML_ASSERT(tensor->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); - - UNUSED(backend); + CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, g_cudaStreams[cuda_ctx->device][0])); } static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { - CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[cuda_ctx->device][0])); UNUSED(backend); } @@ -8546,14 +9070,14 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen UNUSED(cgraph); } -[[noreturn]] static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { +static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { GGML_ASSERT(!"not implemented"); UNUSED(backend); UNUSED(plan); } -[[noreturn]] static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { +static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { GGML_ASSERT(!"not implemented"); UNUSED(backend); @@ -8561,7 +9085,9 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen } static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { - ggml_cuda_set_device(g_main_device); + ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; + + ggml_cuda_set_main_device(cuda_ctx->device); ggml_compute_params params = {}; params.type = GGML_TASK_COMPUTE; @@ -8569,13 +9095,18 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; - if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) { + if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) continue; - } + assert(node->backend == GGML_BACKEND_GPU); + assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + assert(node->extra != nullptr); + for (int j = 0; j < GGML_MAX_SRC; j++) { if (node->src[j] != nullptr) { assert(node->src[j]->backend == GGML_BACKEND_GPU); + assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + assert(node->src[j]->extra != nullptr); } } @@ -8612,27 +9143,98 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph UNUSED(backend); } +static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + return true; + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + struct ggml_tensor * a; + struct ggml_tensor * b; + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != b->ne[3]) { + return false; + } + return true; + } break; + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NORM: + case GGML_OP_REPEAT: + case GGML_OP_GET_ROWS: + case GGML_OP_DUP: + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_CLAMP: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_ALIBI: + case GGML_OP_IM2COL: + case GGML_OP_SUM_ROWS: + case GGML_OP_ARGSORT: + return true; + default: + return false; + } + + UNUSED(backend); +} + static ggml_backend_i cuda_backend_i = { - /* .get_name = */ ggml_backend_cuda_name, - /* .free = */ ggml_backend_cuda_free, - /* .alloc_buffer = */ ggml_backend_cuda_alloc_buffer, - /* .get_alignment = */ ggml_backend_cuda_get_alignment, - /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, - /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, - /* .synchronize = */ ggml_backend_cuda_synchronize, - /* .cpy_tensor_from = */ nullptr, - /* .cpy_tensor_to = */ nullptr, - /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create, - /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free, - /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute, - /* .graph_compute = */ ggml_backend_cuda_graph_compute, - /* .supports_op = */ nullptr, + /* .get_name = */ ggml_backend_cuda_name, + /* .free = */ ggml_backend_cuda_free, + /* .get_default_buffer_type = */ ggml_backend_cuda_get_default_buffer_type, + /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, + /* .cpy_tensor_from_async = */ NULL, + /* .cpy_tensor_to_async = */ NULL, + /* .synchronize = */ ggml_backend_cuda_synchronize, + /* .graph_plan_create = */ ggml_backend_cuda_graph_plan_create, + /* .graph_plan_free = */ ggml_backend_cuda_graph_plan_free, + /* .graph_plan_compute = */ ggml_backend_cuda_graph_plan_compute, + /* .graph_compute = */ ggml_backend_cuda_graph_compute, + /* .supports_op = */ ggml_backend_cuda_supports_op, }; -ggml_backend_t ggml_backend_cuda_init() { +ggml_backend_t ggml_backend_cuda_init(int device) { ggml_init_cublas(); // TODO: remove from ggml.c - ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda; + if (device < 0 || device >= ggml_cuda_get_device_count()) { + fprintf(stderr, "%s: error: invalid device %d\n", __func__, device); + return nullptr; + } + + // not strictly necessary, but it may reduce the overhead of the first graph_compute + ggml_cuda_set_main_device(device); + + ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda { + /* .device = */ device + }; ggml_backend_t cuda_backend = new ggml_backend { /* .interface = */ cuda_backend_i, @@ -8641,3 +9243,27 @@ ggml_backend_t ggml_backend_cuda_init() { return cuda_backend; } + +bool ggml_backend_is_cuda(ggml_backend_t backend) { + return backend->iface.get_name == ggml_backend_cuda_name; +} + +static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) { + ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data); + return cuda_backend; + + UNUSED(params); +} + +static int ggml_backend_cuda_reg_devices() { + int device_count = ggml_cuda_get_device_count(); + //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization + for (int i = 0; i < device_count; i++) { + char name[128]; + snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i); + ggml_backend_register(name, ggml_backend_reg_cuda_init, ggml_backend_cuda_buffer_type(i), (void *) (intptr_t) i); + } + return device_count; +} + +GGML_CONSTRUCTOR(ggml_backend_cuda_reg_devices) diff --git a/ggml-cuda.h b/ggml-cuda.h index 528e66c33..cdb0c0c41 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -49,7 +49,15 @@ GGML_API int ggml_cuda_get_device_count(void); GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size); // backend API -GGML_API ggml_backend_t ggml_backend_cuda_init(void); // TODO: take a list of devices to use +GGML_API ggml_backend_t ggml_backend_cuda_init(int device); + +GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend); +GGML_API int ggml_backend_cuda_get_device(ggml_backend_t backend); + +GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); + +// pinned host buffer for use with CPU backend for faster copies between CPU and GPU +GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void); #ifdef __cplusplus }