From 680e6f91775f972f0df34f56807f30826370db59 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 12 Jul 2023 20:26:18 +0300
Subject: [PATCH] cuda : add gelu support

---
 ggml-cuda.cu | 53 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 53 insertions(+)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 89e69bdc1..dc4b773a6 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -212,6 +212,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 
 #define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
+#define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
@@ -266,6 +267,20 @@ static __global__ void mul_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] * y[i%ky];
 }
 
+static const float GELU_COEF_A    = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+static __global__ void gelu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    float xi = x[i];
+    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+}
+
 static __global__ void silu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1733,6 +1748,11 @@ static void mul_f32_cuda(const float * x, const float * y, float * dst, const in
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
+static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -2327,6 +2347,28 @@ inline void ggml_cuda_op_mul(
     (void) i02;
 }
 
+inline void ggml_cuda_op_gelu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i  != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    gelu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_silu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2986,6 +3028,11 @@ void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
 }
 
+void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_gelu, true, true);
+}
+
 void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
@@ -3382,6 +3429,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul;
             break;
+        case GGML_OP_GELU:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_gelu;
+            break;
         case GGML_OP_SILU:
             if (!any_on_device) {
                 return false;
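
Note (not part of the patch, trim before `git am`): the gelu_f32 kernel above implements the
common tanh approximation GELU(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))); the
kernel's factored form SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi) is algebraically the same
expression. Below is a minimal standalone sketch for sanity-checking the GPU kernel against a
CPU reference. The file name, BLOCK_SIZE macro, and gelu_ref helper are illustrative and not
part of ggml; the kernel is re-declared locally so the file compiles on its own.

// gelu_check.cu (hypothetical) -- build with: nvcc gelu_check.cu -o gelu_check
#include <stdio.h>
#include <math.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 256  // same value as CUDA_GELU_BLOCK_SIZE in the patch

static const float GELU_COEF_A    = 0.044715f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

// same body as the gelu_f32 kernel added by the patch
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    const float xi = x[i];
    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}

// CPU reference in the textbook form: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
static float gelu_ref(float x) {
    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*(x + GELU_COEF_A*x*x*x)));
}

int main(void) {
    const int k = 1024;
    static float hx[1024], hy[1024];
    for (int i = 0; i < k; ++i) {
        hx[i] = -8.0f + 16.0f*i/(k - 1);  // sweep the input range [-8, 8]
    }

    float * dx = nullptr;
    float * dy = nullptr;
    cudaMalloc(&dx, k*sizeof(float));
    cudaMalloc(&dy, k*sizeof(float));
    cudaMemcpy(dx, hx, k*sizeof(float), cudaMemcpyHostToDevice);

    // same launch shape as gelu_f32_cuda in the patch (default stream here)
    const int num_blocks = (k + BLOCK_SIZE - 1) / BLOCK_SIZE;
    gelu_f32<<<num_blocks, BLOCK_SIZE>>>(dx, dy, k);
    cudaMemcpy(hy, dy, k*sizeof(float), cudaMemcpyDeviceToHost);

    float max_err = 0.0f;
    for (int i = 0; i < k; ++i) {
        const float err = fabsf(hy[i] - gelu_ref(hx[i]));
        if (err > max_err) {
            max_err = err;
        }
    }
    printf("max |gpu - cpu| = %g\n", max_err);  // expect ~1e-6, float rounding only

    cudaFree(dx);
    cudaFree(dy);
    return 0;
}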