Add ReLU and SQR CUDA ops to (partially) fix Persimmon offloading (#4041)

* Add ReLU and SQR CUDA ops to fix Persimmon offloading

* Persimmon loader: More helpful error on CUDA/ROCM when offloading too many layers
This commit is contained in:
Kerfuffle 2023-11-13 01:58:15 -07:00 committed by GitHub
parent 21fd874c8d
commit bb50a792ec
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 0 deletions

View file

@ -433,6 +433,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_MUL_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_CLAMP_BLOCK_SIZE 256
@ -553,6 +555,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
static __global__ void relu_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = fmaxf(x[i], 0);
}
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] * x[i];
}
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@ -4759,6 +4779,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
@ -6128,6 +6158,34 @@ inline void ggml_cuda_op_silu(
(void) src1_dd;
}
inline void ggml_cuda_op_relu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_cuda_op_sqr(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
inline void ggml_cuda_op_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@ -7160,6 +7218,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
}
static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
}
static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
}
static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
}
@ -7891,6 +7957,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_UNARY_OP_SILU:
func = ggml_cuda_silu;
break;
case GGML_UNARY_OP_RELU:
func = ggml_cuda_relu;
break;
default:
return false;
} break;
@ -7909,6 +7978,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
case GGML_OP_SCALE:
func = ggml_cuda_scale;
break;
case GGML_OP_SQR:
func = ggml_cuda_sqr;
break;
case GGML_OP_CLAMP:
if (!any_on_device) {
return false;

View file

@ -2877,6 +2877,13 @@ static void llm_load_tensors(
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > int(n_layer + 1)) {
LLAMA_LOG_ERROR("%s: CUDA backend missing Persimmon CUDA ops, can offload at most %ld layers. See: https://github.com/ggerganov/llama.cpp/issues/4038\n",
__func__, n_layer + 1);
throw std::runtime_error("Persimmon CUDA offload failed");
}
#endif
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32