From 2fffa0d61fa10e4b466e78cabcc6a4e16717b580 Mon Sep 17 00:00:00 2001
From: cebtenzzre
Date: Thu, 2 Nov 2023 01:49:44 -0400
Subject: [PATCH] cuda : fix RoPE after #2268 (#3897)

---
Note (review): both hunks replace `-col/ncols` with `-float(col)/ncols`.
If `col` and `ncols` are integers (their declarations are outside these
hunks -- to be confirmed), the old expression was an integer division, so
the exponent passed to powf() lost its fractional part. The rope_neox hunk
also drops the `- row` term, per the updated comment that assumes `ib` is
zero.

 ggml-cuda.cu | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 61cd1747c..57a528ede 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4539,7 +4539,7 @@ static __global__ void rope(
     const int i2 = row/p_delta_rows;
 
     const int p = has_pos ? pos[i2] : 0;
-    const float theta_base = p*powf(freq_base, -col/ncols);
+    const float theta_base = p*powf(freq_base, -float(col)/ncols);
 
     float cos_theta, sin_theta;
     rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
@@ -4566,8 +4566,8 @@ static __global__ void rope_neox(
     const int i = row*ncols + col/2;
     const int i2 = row/p_delta_rows;
 
-    // simplified from `(row * ncols + col) * (-1 / ncols)`
-    const float cur_rot = -col/ncols - row;
+    // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
+    const float cur_rot = -float(col)/ncols;
 
     const int p = has_pos ? pos[i2] : 0;
     const float theta_base = p*powf(freq_base, cur_rot);