From 800c9635b4a9390126f397870f3a825fc7455bd1 Mon Sep 17 00:00:00 2001 From: Jiahao Li Date: Wed, 23 Aug 2023 02:27:06 +0800 Subject: [PATCH] Fix CUDA softmax by subtracting max value before exp (#2665) --- ggml-cuda.cu | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 8ab29bb20..4fe378c21 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3979,24 +3979,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int // the CUDA soft max implementation differs from the CPU implementation // instead of doubles floats are used -// values are also not normalized to the maximum value by subtracting it in the exponential function -// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) { const int row = blockDim.x*blockIdx.x + threadIdx.x; const int block_size = blockDim.y; const int tid = threadIdx.y; - float tmp = 0.0; - - for (int block_start = 0; block_start < ncols; block_start += block_size) { - const int col = block_start + tid; - - if (col >= ncols) { - break; - } + float max_val = -INFINITY; + for (int col = tid; col < ncols; col += block_size) { const int i = row*ncols + col; - const float val = expf(x[i]); + max_val = max(max_val, x[i]); + } + + // find the max value in the block +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32)); + } + + float tmp = 0.f; + + for (int col = tid; col < ncols; col += block_size) { + const int i = row*ncols + col; + const float val = expf(x[i] - max_val); tmp += val; dst[i] = val; } @@ -4007,15 +4012,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - for (int block_start = 0; block_start < ncols; block_start += block_size) { - const int col = block_start + tid; - - if (col >= ncols) { - break; - } + const float inv_tmp = 1.f / tmp; + for (int col = tid; col < ncols; col += block_size) { const int i = row*ncols + col; - dst[i] /= tmp; + dst[i] *= inv_tmp; } }