From e29757e0f7535d1ac314300f0324684cc785e06c Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Sat, 18 Nov 2023 04:20:19 -0700 Subject: [PATCH] ggml-cuda.cu: Move static items into anonymous namespace --- ggml-cuda.cu | 567 +++++++++++++++++++++++++++------------------------ 1 file changed, 299 insertions(+), 268 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 50e03de50..8f6e4d18c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -83,6 +83,11 @@ #include "ggml.h" #include "ggml-backend-impl.h" +#define START_ANONYMOUS_NAMESPACE namespace { +#define END_ANONYMOUS_NAMESPACE } + +START_ANONYMOUS_NAMESPACE + #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 #define CC_OFFSET_AMD 1000000 @@ -126,7 +131,7 @@ #endif typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); -static __device__ __forceinline__ int __vsubss4(const int a, const int b) { +__device__ __forceinline__ int __vsubss4(const int a, const int b) { const int8x4_t va = reinterpret_cast(a); const int8x4_t vb = reinterpret_cast(b); #if __has_builtin(__builtin_elementwise_sub_sat) @@ -146,7 +151,7 @@ static __device__ __forceinline__ int __vsubss4(const int a, const int b) { #endif // __has_builtin(__builtin_elementwise_sub_sat) } -static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { +__device__ __forceinline__ int __dp4a(const int a, const int b, int c) { #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) c = __builtin_amdgcn_sdot4(a, b, c, false); #elif defined(__gfx1100__) @@ -234,7 +239,7 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif //GGML_CUDA_F16 -static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { +__device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment int x32 = 0; @@ -244,7 +249,7 @@ static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const return x32; } -static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { +__device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment int x32 = 0; @@ -254,11 +259,11 @@ static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, con return x32; } -static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { +__device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } -static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { +__device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } @@ -469,7 +474,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA #define MUL_MAT_SRC1_COL_STRIDE 128 #define MAX_STREAMS 8 -static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { { nullptr } }; +cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { { nullptr } }; struct ggml_tensor_extra_gpu { void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors @@ -489,18 +494,18 @@ inline cudaError_t ggml_cuda_set_device(const int device) { return cudaSetDevice(device); } -static int g_device_count = -1; -static int g_main_device = 0; -static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; -static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; +int g_device_count = -1; +int g_main_device = 0; +int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; +float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; -static void * g_scratch_buffer = nullptr; -static size_t g_scratch_size = 0; // disabled by default -static size_t g_scratch_offset = 0; +void * g_scratch_buffer = nullptr; +size_t g_scratch_size = 0; // disabled by default +size_t g_scratch_offset = 0; -static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; +cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; -static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { +__global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= kx) { @@ -509,7 +514,7 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co dst[i] = x[i] + y[i%ky]; } -static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { +__global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -518,7 +523,7 @@ static __global__ void add_f16_f32_f16(const half * x, const float * y, half * d dst[i] = __hadd(x[i], __float2half(y[i])); } -static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) { +__global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -527,7 +532,7 @@ static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst[i] = __half2float(x[i]) + y[i]; } -static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { +__global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= kx) { @@ -536,7 +541,7 @@ static __global__ void mul_f32(const float * x, const float * y, float * dst, co dst[i] = x[i] * y[i%ky]; } -static __global__ void gelu_f32(const float * x, float * dst, const int k) { +__global__ void gelu_f32(const float * x, float * dst, const int k) { const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -549,7 +554,7 @@ static __global__ void gelu_f32(const float * x, float * dst, const int k) { dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))); } -static __global__ void silu_f32(const float * x, float * dst, const int k) { +__global__ void silu_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -558,7 +563,7 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { dst[i] = x[i] / (1.0f + expf(-x[i])); } -static __global__ void relu_f32(const float * x, float * dst, const int k) { +__global__ void relu_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -567,7 +572,7 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) { dst[i] = fmaxf(x[i], 0); } -static __global__ void sqr_f32(const float * x, float * dst, const int k) { +__global__ void sqr_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -576,7 +581,7 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) { dst[i] = x[i] * x[i]; } -static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { +__device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32); @@ -586,7 +591,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { } template -static __global__ void norm_f32(const float * x, float * dst, const int ncols) { +__global__ void norm_f32(const float * x, float * dst, const int ncols) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; @@ -623,7 +628,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) { } } -static __device__ __forceinline__ float warp_reduce_sum(float x) { +__device__ __forceinline__ float warp_reduce_sum(float x) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { x += __shfl_xor_sync(0xffffffff, x, mask, 32); @@ -632,7 +637,7 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) { } template -static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { +__global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; @@ -665,7 +670,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol } } -static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +__device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; const dfloat d = x[ib].d; @@ -684,7 +689,7 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ +__device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; const dfloat d = __low2half(x[ib].dm); @@ -704,7 +709,7 @@ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +__device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_0 * x = (const block_q5_0 *) vx; const dfloat d = x[ib].d; @@ -727,7 +732,7 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ +__device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; const dfloat d = __low2half(x[ib].dm); @@ -751,7 +756,7 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in #endif // GGML_CUDA_F16 } -static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ +__device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q8_0 * x = (const block_q8_0 *) vx; const dfloat d = x[ib].d; @@ -770,7 +775,7 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in //================================== k-quants template -static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { +__global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int i = blockIdx.x; const block_q2_K * x = (const block_q2_K *) vx; @@ -804,7 +809,7 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t } template -static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { +__global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int i = blockIdx.x; const block_q3_K * x = (const block_q3_K *) vx; @@ -832,7 +837,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t const uint8_t * q = x[i].qs + 32*n; const uint8_t * hm = x[i].hmask; - for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + for (int l = l0; l < l0+4; ++l) { y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); } #else const int tid = threadIdx.x; const int is = tid/16; // 0 or 1 @@ -858,7 +863,7 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t } #if QK_K == 256 -static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { +inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { d = q[j] & 63; m = q[j + 4] & 63; } else { @@ -869,7 +874,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t #endif template -static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { +__global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q4_K * x = (const block_q4_K *) vx; const int i = blockIdx.x; @@ -910,7 +915,7 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t } template -static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { +__global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q5_K * x = (const block_q5_K *) vx; const int i = blockIdx.x; @@ -957,7 +962,7 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t } template -static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { +__global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { const block_q6_K * x = (const block_q6_K *) vx; const int i = blockIdx.x; @@ -1001,12 +1006,12 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t #endif } -static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { +__global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; + if (row > nrows) { return; } const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; @@ -1107,10 +1112,10 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, } } -static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { +__global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; + if (row > nrows) { return; } const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; @@ -1211,10 +1216,10 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, } } -static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { +__global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; + if (row > nrows) { return; } const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; @@ -1347,7 +1352,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, } } -static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { +__global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { const int row = blockIdx.x; const int num_blocks_per_row = ncols / QK_K; @@ -1463,12 +1468,12 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, } } -static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { +__global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; + if (row > nrows) { return; } const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; @@ -1573,7 +1578,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, } } -static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ +__device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){ const half * x = (const half *) vx; // automatic half -> float type cast if dfloat == float @@ -1581,7 +1586,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, v.y = x[ib + iqs + 1]; } -static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){ +__device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){ const float * x = (const float *) vx; // automatic half -> float type cast if dfloat == float @@ -1589,7 +1594,7 @@ static __device__ void convert_f32(const void * vx, const int ib, const int iqs, v.y = x[ib + iqs + 1]; } -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { +__global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { const int ix = blockDim.x*blockIdx.x + threadIdx.x; if (ix >= kx_padded) { @@ -1629,7 +1634,7 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest } template -static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) { +__global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) { const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2; const int row = blockDim.y*blockIdx.y + threadIdx.y; @@ -1657,7 +1662,7 @@ static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst } template -static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { +__global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) { const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; if (i >= k) { @@ -1683,7 +1688,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 -template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( +template __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( const int * v, const int * u, const float & d4, const half2 & ds8) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -1712,7 +1717,7 @@ template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp #define VDR_Q4_1_Q8_1_MMVQ 2 #define VDR_Q4_1_Q8_1_MMQ 4 -template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( +template __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( const int * v, const int * u, const half2 & dm4, const half2 & ds8) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -1750,7 +1755,7 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp #define VDR_Q5_0_Q8_1_MMVQ 2 #define VDR_Q5_0_Q8_1_MMQ 4 -template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( +template __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -1786,7 +1791,7 @@ template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp #define VDR_Q5_1_Q8_1_MMVQ 2 #define VDR_Q5_1_Q8_1_MMQ 4 -template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( +template __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -1832,7 +1837,7 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp #define VDR_Q8_0_Q8_1_MMVQ 2 #define VDR_Q8_0_Q8_1_MMQ 8 -template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( +template __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( const int * v, const int * u, const float & d8_0, const float & d8_1) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -1851,7 +1856,7 @@ template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp #endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( +template __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( const int * v, const int * u, const half2 & dm8, const half2 & ds8) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -1886,7 +1891,7 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp #define VDR_Q2_K_Q8_1_MMQ 2 // contiguous v/x values -static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( +__device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, const half2 & dm2, const float * __restrict__ d8) { @@ -1919,7 +1924,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( } // contiguous u/y values -static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( +__device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, const half2 & dm2, const float & d8) { @@ -1960,7 +1965,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( #define VDR_Q3_K_Q8_1_MMQ 2 // contiguous v/x values -static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( +__device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, const int & scale_offset, const float & d3, const float * __restrict__ d8) { @@ -1998,7 +2003,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( } // contiguous u/y values -static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( +__device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d3, const float & d8) { @@ -2027,7 +2032,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( #define VDR_Q4_K_Q8_1_MMQ 8 // contiguous v/x values -static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( +__device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { @@ -2058,7 +2063,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( } // contiguous u/y values -static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( +__device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { @@ -2095,7 +2100,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( #define VDR_Q5_K_Q8_1_MMQ 8 // contiguous v/x values -static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( +__device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { @@ -2133,7 +2138,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( } // contiguous u/y values -static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( +__device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { @@ -2170,7 +2175,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( #define VDR_Q6_K_Q8_1_MMQ 8 // contiguous v/x values -static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( +__device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d, const float * __restrict__ d8) { @@ -2198,7 +2203,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( } // contiguous u/y values -static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( +__device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, const float & d6, const float * __restrict__ d8) { @@ -2229,7 +2234,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( #endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -static __device__ __forceinline__ float vec_dot_q4_0_q8_1( +__device__ __forceinline__ float vec_dot_q4_0_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; @@ -2247,7 +2252,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); } -template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { +template __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { (void)x_qh; (void)x_sc; __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; @@ -2257,7 +2262,7 @@ template static __device__ __forceinline__ void allocate_tiles_q4_0( *x_dm = (half2 *) tile_x_d; } -template static __device__ __forceinline__ void load_tiles_q4_0( +template __device__ __forceinline__ void load_tiles_q4_0( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { (void)x_qh; (void)x_sc; @@ -2304,7 +2309,7 @@ template static __device__ __forceinlin } } -static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( +__device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; (void)x_sc; @@ -2325,7 +2330,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); } -static __device__ __forceinline__ float vec_dot_q4_1_q8_1( +__device__ __forceinline__ float vec_dot_q4_1_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; @@ -2343,7 +2348,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); } -template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { +template __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { (void)x_qh; (void)x_sc; __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; @@ -2353,7 +2358,7 @@ template static __device__ __forceinline__ void allocate_tiles_q4_1( *x_dm = tile_x_dm; } -template static __device__ __forceinline__ void load_tiles_q4_1( +template __device__ __forceinline__ void load_tiles_q4_1( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { (void)x_qh; (void)x_sc; @@ -2398,7 +2403,7 @@ template static __device__ __forceinlin } } -static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( +__device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; (void)x_sc; @@ -2418,7 +2423,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); } -static __device__ __forceinline__ float vec_dot_q5_0_q8_1( +__device__ __forceinline__ float vec_dot_q5_0_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; @@ -2438,7 +2443,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1( return vec_dot_q5_0_q8_1_impl(vl, vh, u, bq5_0->d, bq8_1->ds); } -template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { +template __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { (void)x_qh; (void)x_sc; __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; @@ -2448,7 +2453,7 @@ template static __device__ __forceinline__ void allocate_tiles_q5_0( *x_dm = (half2 *) tile_x_d; } -template static __device__ __forceinline__ void load_tiles_q5_0( +template __device__ __forceinline__ void load_tiles_q5_0( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { (void)x_qh; (void)x_sc; @@ -2513,7 +2518,7 @@ template static __device__ __forceinlin } } -static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( +__device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; (void)x_sc; @@ -2535,7 +2540,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); } -static __device__ __forceinline__ float vec_dot_q5_1_q8_1( +__device__ __forceinline__ float vec_dot_q5_1_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; @@ -2555,7 +2560,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); } -template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { +template __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { (void)x_qh; (void)x_sc; __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; @@ -2565,7 +2570,7 @@ template static __device__ __forceinline__ void allocate_tiles_q5_1( *x_dm = tile_x_dm; } -template static __device__ __forceinline__ void load_tiles_q5_1( +template __device__ __forceinline__ void load_tiles_q5_1( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { (void)x_qh; (void)x_sc; @@ -2627,7 +2632,7 @@ template static __device__ __forceinlin } } -static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( +__device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; (void)x_sc; @@ -2647,7 +2652,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); } -static __device__ __forceinline__ float vec_dot_q8_0_q8_1( +__device__ __forceinline__ float vec_dot_q8_0_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; @@ -2664,7 +2669,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); } -template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { +template __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { (void)x_qh; (void)x_sc; __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; @@ -2674,7 +2679,7 @@ template static __device__ __forceinline__ void allocate_tiles_q8_0( *x_dm = (half2 *) tile_x_d; } -template static __device__ __forceinline__ void load_tiles_q8_0( +template __device__ __forceinline__ void load_tiles_q8_0( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { (void)x_qh; (void)x_sc; @@ -2720,7 +2725,7 @@ template static __device__ __forceinlin } } -static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( + __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; (void)x_sc; @@ -2733,7 +2738,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); } -static __device__ __forceinline__ float vec_dot_q2_K_q8_1( + __device__ __forceinline__ float vec_dot_q2_K_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { const block_q2_K * bq2_K = (const block_q2_K *) vbq; @@ -2756,7 +2761,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1( return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); } -template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { +template __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { (void)x_qh; __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; @@ -2768,7 +2773,7 @@ template static __device__ __forceinline__ void allocate_tiles_q2_K( *x_sc = tile_x_sc; } -template static __device__ __forceinline__ void load_tiles_q2_K( +template __device__ __forceinline__ void load_tiles_q2_K( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { (void)x_qh; @@ -2826,7 +2831,7 @@ template static __device__ __forceinlin } } -static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( +__device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; @@ -2851,7 +2856,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); } -static __device__ __forceinline__ float vec_dot_q3_K_q8_1( +__device__ __forceinline__ float vec_dot_q3_K_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { const block_q3_K * bq3_K = (const block_q3_K *) vbq; @@ -2878,7 +2883,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); } -template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { +template __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; @@ -2891,7 +2896,7 @@ template static __device__ __forceinline__ void allocate_tiles_q3_K( *x_sc = tile_x_sc; } -template static __device__ __forceinline__ void load_tiles_q3_K( +template __device__ __forceinline__ void load_tiles_q3_K( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { @@ -2975,7 +2980,7 @@ template static __device__ __forceinlin } } -static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( +__device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { @@ -3004,7 +3009,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); } -static __device__ __forceinline__ float vec_dot_q4_K_q8_1( +__device__ __forceinline__ float vec_dot_q4_K_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { #ifndef GGML_QKK_64 @@ -3098,7 +3103,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( #endif } -template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { +template __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { (void)x_qh; __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; @@ -3110,7 +3115,7 @@ template static __device__ __forceinline__ void allocate_tiles_q4_K( *x_sc = tile_x_sc; } -template static __device__ __forceinline__ void load_tiles_q4_K( +template __device__ __forceinline__ void load_tiles_q4_K( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { (void)x_qh; @@ -3180,7 +3185,7 @@ template static __device__ __forceinlin } } -static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( +__device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; @@ -3192,7 +3197,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); } -static __device__ __forceinline__ float vec_dot_q5_K_q8_1( +__device__ __forceinline__ float vec_dot_q5_K_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { #ifndef GGML_QKK_64 @@ -3282,7 +3287,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( #endif } -template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { +template __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { (void)x_qh; __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; @@ -3294,7 +3299,7 @@ template static __device__ __forceinline__ void allocate_tiles_q5_K( *x_sc = tile_x_sc; } -template static __device__ __forceinline__ void load_tiles_q5_K( +template __device__ __forceinline__ void load_tiles_q5_K( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { (void)x_qh; @@ -3375,7 +3380,7 @@ template static __device__ __forceinlin } } -static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( +__device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; @@ -3388,7 +3393,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); } -static __device__ __forceinline__ float vec_dot_q6_K_q8_1( +__device__ __forceinline__ float vec_dot_q6_K_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { const block_q6_K * bq6_K = (const block_q6_K *) vbq; @@ -3414,7 +3419,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); } -template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { +template __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { (void)x_qh; __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; @@ -3426,7 +3431,7 @@ template static __device__ __forceinline__ void allocate_tiles_q6_K( *x_sc = tile_x_sc; } -template static __device__ __forceinline__ void load_tiles_q6_K( +template __device__ __forceinline__ void load_tiles_q6_K( const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { (void)x_qh; @@ -3498,7 +3503,7 @@ template static __device__ __forceinlin } } -static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( +__device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; @@ -3515,7 +3520,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( template -static __device__ __forceinline__ void mul_mat_q( +__device__ __forceinline__ void mul_mat_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { @@ -3642,7 +3647,7 @@ static __device__ __forceinline__ void mul_mat_q( #define MMQ_Y_Q4_0_PASCAL 64 #define NWARPS_Q4_0_PASCAL 8 -template static __global__ void +template __global__ void #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2) @@ -3709,7 +3714,7 @@ template static __global__ void #define MMQ_Y_Q4_1_PASCAL 64 #define NWARPS_Q4_1_PASCAL 8 -template static __global__ void +template __global__ void #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2) @@ -3778,7 +3783,7 @@ template static __global__ void #define MMQ_Y_Q5_0_PASCAL 64 #define NWARPS_Q5_0_PASCAL 8 -template static __global__ void +template __global__ void #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2) @@ -3845,7 +3850,7 @@ template static __global__ void #define MMQ_Y_Q5_1_PASCAL 64 #define NWARPS_Q5_1_PASCAL 8 -template static __global__ void +template __global__ void #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2) @@ -3912,7 +3917,7 @@ mul_mat_q5_1( #define MMQ_Y_Q8_0_PASCAL 64 #define NWARPS_Q8_0_PASCAL 8 -template static __global__ void +template __global__ void #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2) @@ -3979,7 +3984,7 @@ template static __global__ void #define MMQ_Y_Q2_K_PASCAL 64 #define NWARPS_Q2_K_PASCAL 8 -template static __global__ void +template __global__ void #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2) @@ -4046,7 +4051,7 @@ mul_mat_q2_K( #define MMQ_Y_Q3_K_PASCAL 64 #define NWARPS_Q3_K_PASCAL 8 -template static __global__ void +template __global__ void #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2) @@ -4115,7 +4120,7 @@ template static __global__ void #define MMQ_Y_Q4_K_PASCAL 64 #define NWARPS_Q4_K_PASCAL 8 -template static __global__ void +template __global__ void #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2) @@ -4184,7 +4189,7 @@ template static __global__ void #define MMQ_Y_Q5_K_PASCAL 64 #define NWARPS_Q5_K_PASCAL 8 -template static __global__ void +template __global__ void #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2) @@ -4251,7 +4256,7 @@ mul_mat_q5_K( #define MMQ_Y_Q6_K_PASCAL 64 #define NWARPS_Q6_K_PASCAL 8 -template static __global__ void +template __global__ void #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2) @@ -4302,7 +4307,7 @@ template static __global__ void } template -static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { +__global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row >= nrows) { @@ -4340,7 +4345,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * } template -static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { +__global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { // qk = quantized weights per x block // qr = number of quantized weights per data value in x block const int row = blockIdx.x*blockDim.y + threadIdx.y; @@ -4407,7 +4412,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons } } -static __global__ void mul_mat_p021_f16_f32( +__global__ void mul_mat_p021_f16_f32( const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { @@ -4457,7 +4462,7 @@ static __global__ void mul_mat_p021_f16_f32( } } -static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous +__global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { @@ -4503,21 +4508,21 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous } } -static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { +__device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { const float * xi = (const float *) cxi; float * dsti = (float *) cdsti; *dsti = *xi; } -static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { +__device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { const float * xi = (const float *) cxi; half * dsti = (half *) cdsti; *dsti = __float2half(*xi); } -static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { +__device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { const half * xi = (const half *) cxi; half * dsti = (half *) cdsti; @@ -4525,7 +4530,7 @@ static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, +__global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -4549,7 +4554,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } -static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { +__device__ float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); } @@ -4560,7 +4565,7 @@ struct rope_corr_dims { // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static __device__ void rope_yarn( +__device__ void rope_yarn( float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, float * cos_theta, float * sin_theta ) { @@ -4580,7 +4585,7 @@ static __device__ void rope_yarn( // rope == RoPE == rotary positional embedding template -static __global__ void rope( +__global__ void rope( const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims ) { @@ -4608,7 +4613,7 @@ static __global__ void rope( } template -static __global__ void rope_neox( +__global__ void rope_neox( const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims ) { @@ -4638,7 +4643,7 @@ static __global__ void rope_neox( dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_glm_f32( +__global__ void rope_glm_f32( const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, int n_ctx ) { @@ -4678,7 +4683,7 @@ static __global__ void rope_glm_f32( dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; } -static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, +__global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, const int n_heads_log2_floor, const float m0, const float m1) { const int col = blockDim.x*blockIdx.x + threadIdx.x; @@ -4701,7 +4706,7 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols, dst[i] = col * m_k + x[i]; } -static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { +__global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { const int col = blockDim.y*blockIdx.y + threadIdx.y; const int row = blockDim.x*blockIdx.x + threadIdx.x; @@ -4716,7 +4721,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int // the CUDA soft max implementation differs from the CPU implementation // instead of doubles floats are used -static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) { +__global__ void soft_max_f32(const float * x, float * dst, const int ncols) { const int row = blockDim.x*blockIdx.x + threadIdx.x; const int block_size = blockDim.y; const int tid = threadIdx.y; @@ -4757,7 +4762,7 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol } } -static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { +__global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -4767,7 +4772,7 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale dst[i] = scale * x[i]; } -static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) { +__global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -4777,7 +4782,7 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min, dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); } -static __global__ void im2col_f32_f16( + __global__ void im2col_f32_f16( const float * x, half * dst, int ofs0, int ofs1, int IW, int IH, int CHW, int s0, int s1, int p0, int p1, int d0, int d1) { @@ -4797,54 +4802,54 @@ static __global__ void im2col_f32_f16( } template -static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { +void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); const dim3 block_nums(block_num_x, nrows, 1); k_get_rows<<>>(x, y, dst, ncols); } -static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { +void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; add_f32<<>>(x, y, dst, kx, ky); } -static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { +void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; add_f16_f32_f16<<>>(x, y, dst, k); } -static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) { +void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; add_f16_f32_f32<<>>(x, y, dst, k); } -static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { +void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; mul_f32<<>>(x, y, dst, kx, ky); } -static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { +void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; gelu_f32<<>>(x, dst, k); } -static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { +void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; silu_f32<<>>(x, dst, k); } -static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { +void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; relu_f32<<>>(x, dst, k); } -static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { +void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE; sqr_f32<<>>(x, dst, k); } -static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); @@ -4855,7 +4860,7 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i } } -static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { +void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); @@ -4866,7 +4871,7 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con } } -static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) { +void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) { const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ky, 1); const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1); @@ -4874,37 +4879,37 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con } template -static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } template -static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } template -static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } template -static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } template -static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); } template -static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); @@ -4914,7 +4919,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); @@ -4924,13 +4929,13 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q4_K<<>>(vx, y); } template -static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); @@ -4940,7 +4945,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); @@ -4949,7 +4954,7 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu #endif } -static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead @@ -4959,7 +4964,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -4968,7 +4973,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -4977,7 +4982,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -4986,7 +4991,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -4995,7 +5000,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, <<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 const int block_num_y = (nrows + ny - 1) / ny; @@ -5004,7 +5009,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; @@ -5013,7 +5018,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; @@ -5022,13 +5027,13 @@ static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); } -static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const dim3 block_dims(32, 1, 1); dequantize_mul_mat_vec_q5_k<<>>(vx, y, dst, ncols); } -static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; @@ -5037,7 +5042,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); } -static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5046,7 +5051,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK4_1 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5055,7 +5060,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK5_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5064,7 +5069,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK5_1 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5073,7 +5078,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK8_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5082,7 +5087,7 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5091,7 +5096,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5100,7 +5105,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5109,7 +5114,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5118,7 +5123,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5127,17 +5132,17 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { +void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<1, 1, convert_f16><<>>(vx, y, k); } -static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) { +void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; dequantize_block<1, 1, convert_f32><<>>(vx, y, k); } -static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -5146,7 +5151,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa <<>>(vx, y, dst, ncols, nrows); } -static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { +to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; @@ -5175,7 +5180,7 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { } } -static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { +to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; @@ -5204,7 +5209,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } -static void ggml_mul_mat_q4_0_q8_1_cuda( +void ggml_mul_mat_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5249,7 +5254,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( } } -static void ggml_mul_mat_q4_1_q8_1_cuda( +void ggml_mul_mat_q4_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5294,7 +5299,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( } } -static void ggml_mul_mat_q5_0_q8_1_cuda( +void ggml_mul_mat_q5_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5339,7 +5344,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( } } -static void ggml_mul_mat_q5_1_q8_1_cuda( +void ggml_mul_mat_q5_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5384,7 +5389,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( } } -static void ggml_mul_mat_q8_0_q8_1_cuda( +void ggml_mul_mat_q8_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5429,7 +5434,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( } } -static void ggml_mul_mat_q2_K_q8_1_cuda( +void ggml_mul_mat_q2_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5474,7 +5479,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( } } -static void ggml_mul_mat_q3_K_q8_1_cuda( +void ggml_mul_mat_q3_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5522,7 +5527,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( #endif } -static void ggml_mul_mat_q4_K_q8_1_cuda( +void ggml_mul_mat_q4_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5567,7 +5572,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( } } -static void ggml_mul_mat_q5_K_q8_1_cuda( +void ggml_mul_mat_q5_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5612,7 +5617,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( } } -static void ggml_mul_mat_q6_K_q8_1_cuda( +void ggml_mul_mat_q6_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5657,7 +5662,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda( } } -static void ggml_mul_mat_p021_f16_f32_cuda( +void ggml_mul_mat_p021_f16_f32_cuda( const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, cudaStream_t stream) { @@ -5666,7 +5671,7 @@ static void ggml_mul_mat_p021_f16_f32_cuda( mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); } -static void ggml_mul_mat_vec_nc_f16_f32_cuda( +void ggml_mul_mat_vec_nc_f16_f32_cuda( const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { @@ -5676,7 +5681,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda( (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); } -static void ggml_cpy_f32_f32_cuda( +void ggml_cpy_f32_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { @@ -5686,7 +5691,7 @@ static void ggml_cpy_f32_f32_cuda( (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } -static void ggml_cpy_f32_f16_cuda( +void ggml_cpy_f32_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { @@ -5696,7 +5701,7 @@ static void ggml_cpy_f32_f16_cuda( (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } -static void ggml_cpy_f16_f16_cuda( +void ggml_cpy_f16_f16_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { @@ -5706,18 +5711,18 @@ static void ggml_cpy_f16_f16_cuda( (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } -static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { +void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; scale_f32<<>>(x, dst, scale, k); } -static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) { +void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE; clamp_f32<<>>(x, dst, min, max, k); } template -static void rope_cuda( +void rope_cuda( const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream ) { @@ -5737,7 +5742,7 @@ static void rope_cuda( } template -static void rope_neox_cuda( +void rope_neox_cuda( const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream ) { @@ -5756,7 +5761,7 @@ static void rope_neox_cuda( } } -static void rope_glm_f32_cuda( +void rope_glm_f32_cuda( const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, int n_ctx, cudaStream_t stream ) { @@ -5767,7 +5772,7 @@ static void rope_glm_f32_cuda( rope_glm_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx); } -static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, +void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int k_rows, const int n_heads_log2_floor, const float m0, const float m1, cudaStream_t stream) { const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1); @@ -5776,20 +5781,20 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const alibi_f32<<>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1); } -static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { +void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1); const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; const dim3 block_nums(nrows_x, block_num_x, 1); diag_mask_inf_f32<<>>(x, dst, ncols_x, rows_per_channel, n_past); } -static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) { +void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) { const dim3 block_dims(1, WARP_SIZE, 1); const dim3 block_nums(nrows_x, 1, 1); soft_max_f32<<>>(x, dst, ncols_x); } -static void im2col_f32_f16_cuda(const float * x, half * dst, +void im2col_f32_f16_cuda(const float * x, half * dst, int OH, int IW, int IH, int OW, int IC, int KH, int KW, int N, int ofs0, int ofs1, int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) { @@ -5820,10 +5825,10 @@ struct cuda_buffer { size_t size = 0; }; -static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS]; -static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; +cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS]; +std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; -static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { +void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { scoped_spin_lock lock(g_cuda_pool_lock); int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -5877,7 +5882,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } -static void ggml_cuda_pool_free(void * ptr, size_t size) { +void ggml_cuda_pool_free(void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -5894,7 +5899,9 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) { CUDA_CHECK(cudaFree(ptr)); } -static bool g_cublas_loaded = false; +bool g_cublas_loaded = false; + +END_ANONYMOUS_NAMESPACE bool ggml_cublas_loaded(void) { return g_cublas_loaded; @@ -6016,7 +6023,9 @@ void ggml_cuda_host_free(void * ptr) { CUDA_CHECK(cudaFreeHost(ptr)); } -static cudaError_t ggml_cuda_cpy_tensor_2d( +namespace { + +cudaError_t ggml_cuda_cpy_tensor_2d( void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { cudaMemcpyKind kind; @@ -6063,7 +6072,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( return cudaSuccess; } -static void ggml_cuda_op_repeat( +void ggml_cuda_op_repeat( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { // guaranteed to be an integer due to the check in ggml_can_repeat @@ -6120,7 +6129,7 @@ static void ggml_cuda_op_repeat( (void) src1_d; } -static void ggml_cuda_op_get_rows( +void ggml_cuda_op_get_rows( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { @@ -6164,6 +6173,8 @@ static void ggml_cuda_op_get_rows( } } +END_ANONYMOUS_NAMESPACE + inline void ggml_cuda_op_add( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -6359,7 +6370,9 @@ inline void ggml_cuda_op_mul_mat_q( (void) src1_ddf_i; } -static int64_t get_row_rounding(ggml_type type) { +namespace { + +int64_t get_row_rounding(ggml_type type) { int64_t min_compute_capability = INT_MAX; int64_t max_compute_capability = INT_MIN; for (int64_t id = 0; id < g_device_count; ++id) { @@ -6420,6 +6433,8 @@ static int64_t get_row_rounding(ggml_type type) { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) } +END_ANONYMOUS_NAMESPACE + inline void ggml_cuda_op_mul_mat_vec_q( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, @@ -6893,7 +6908,9 @@ inline void ggml_cuda_op_clamp( (void) src1_dd; } -static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { +namespace { + +void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { const int64_t nrows0 = ggml_nrows(src0); const bool use_src1 = src1 != nullptr; @@ -6970,7 +6987,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s } } -static void ggml_cuda_set_peer_access(const int n_tokens) { +void ggml_cuda_set_peer_access(const int n_tokens) { static bool peer_access_enabled = false; const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; @@ -7007,7 +7024,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens) { peer_access_enabled = enable_peer_access; } -static void ggml_cuda_op_mul_mat( +void ggml_cuda_op_mul_mat( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, const bool convert_src1_to_q8_1) { @@ -7308,46 +7325,48 @@ static void ggml_cuda_op_mul_mat( } } -static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_repeat); } -static void ggml_cuda_get_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_get_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_get_rows); } -static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add); } -static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul); } -static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu); } -static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); } -static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu); } -static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr); } -static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); } -static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm); } +END_ANONYMOUS_NAMESPACE + bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { if (!g_cublas_loaded) { return false; } @@ -7363,7 +7382,9 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); } -static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +namespace { + +void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation @@ -7392,7 +7413,7 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); } -static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ +void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); @@ -7427,7 +7448,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); } -__global__ static void k_compute_batched_ptrs( +__global__ void k_compute_batched_ptrs( const half * src0_as_f16, const half * src1_as_f16, half * dst_f16, const void ** ptrs_src, void ** ptrs_dst, int ne12, int ne13, @@ -7451,7 +7472,7 @@ __global__ static void k_compute_batched_ptrs( ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2; } -static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); @@ -7601,7 +7622,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_cuda_pool_free(dst_f16, dst_as); } -static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && (src1->backend == GGML_BACKEND_GPU) && @@ -7674,15 +7695,15 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } -static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } -static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp); } -static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -7735,38 +7756,40 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg (void) dst; } -static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_cpy(src0, dst, nullptr); (void) src1; } -static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf); } -static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max); } -static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope); } -static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi); } -static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col); } -static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { (void) src0; (void) src1; (void) dst; } +END_ANONYMOUS_NAMESPACE + void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { const int64_t nrows = ggml_nrows(tensor); @@ -7867,10 +7890,12 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) { delete extra; } -static ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr; -static size_t g_temp_tensor_extra_index = 0; +namespace { -static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { +ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr; +size_t g_temp_tensor_extra_index = 0; + +ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { if (g_temp_tensor_extras == nullptr) { g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_CUDA_MAX_NODES]; } @@ -7883,7 +7908,7 @@ static ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() { return extra; } -static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) { +void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) { if (scratch && g_scratch_size == 0) { return; } @@ -7956,6 +7981,8 @@ static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scra tensor->extra = extra; } +END_ANONYMOUS_NAMESPACE + void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) { if (g_scratch_size == 0) { return; @@ -8179,13 +8206,15 @@ void ggml_cuda_get_device_description(int device, char * description, size_t des struct ggml_backend_context_cuda { }; -static const char * ggml_backend_cuda_name(ggml_backend_t backend) { +namespace { + +const char * ggml_backend_cuda_name(ggml_backend_t backend) { return GGML_CUDA_NAME; UNUSED(backend); } -static void ggml_backend_cuda_free(ggml_backend_t backend) { +void ggml_backend_cuda_free(ggml_backend_t backend) { ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context; delete cuda_ctx; delete backend; @@ -8215,18 +8244,18 @@ struct ggml_backend_buffer_context_cuda { } }; -static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { +void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; CUDA_CHECK(cudaFree(ctx->device)); delete ctx; } -static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) { +void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) { ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; return ctx->device; } -static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { int64_t row_low = 0; int64_t row_high = ggml_nrows(tensor); int64_t nrows_split = row_high - row_low; @@ -8247,7 +8276,7 @@ static size_t ggml_backend_cuda_buffer_get_alloc_size(ggml_backend_buffer_t buff UNUSED(buffer); } -static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ggml_backend_buffer_context_cuda * ctx = (ggml_backend_buffer_context_cuda *)buffer->context; if (tensor->view_src != NULL && tensor->view_offs == 0) { @@ -8281,7 +8310,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g UNUSED(buffer); } -static struct ggml_backend_buffer_i cuda_backend_buffer_interface = { +struct ggml_backend_buffer_i cuda_backend_buffer_interface = { /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer, /* .get_base = */ ggml_backend_cuda_buffer_get_base, /* .get_alloc_size = */ ggml_backend_cuda_buffer_get_alloc_size, @@ -8289,7 +8318,7 @@ static struct ggml_backend_buffer_i cuda_backend_buffer_interface = { /* .free_tensor = */ NULL, }; -static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) { +ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backend, size_t size) { ggml_cuda_set_device(g_main_device); ggml_backend_buffer_context_cuda * ctx = new ggml_backend_buffer_context_cuda; @@ -8302,12 +8331,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_alloc_buffer(ggml_backend_t backe return ggml_backend_buffer_init(backend, cuda_backend_buffer_interface, ctx, size); } -static size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) { +size_t ggml_backend_cuda_get_alignment(ggml_backend_t backend) { return 128; UNUSED(backend); } -static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); @@ -8317,7 +8346,7 @@ static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tens UNUSED(backend); } -static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { +void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU); @@ -8327,13 +8356,13 @@ static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggm UNUSED(backend); } -static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { +void ggml_backend_cuda_synchronize(ggml_backend_t backend) { CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); UNUSED(backend); } -static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, ggml_cgraph * cgraph) { +ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backend_t backend, ggml_cgraph * cgraph) { GGML_ASSERT(!"not implemented"); return nullptr; @@ -8342,21 +8371,21 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen UNUSED(cgraph); } -[[noreturn]] static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { +[[noreturn]] void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { GGML_ASSERT(!"not implemented"); UNUSED(backend); UNUSED(plan); } -[[noreturn]] static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { +[[noreturn]] void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { GGML_ASSERT(!"not implemented"); UNUSED(backend); UNUSED(plan); } -static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { +void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_cuda_set_device(g_main_device); ggml_compute_params params = {}; @@ -8408,7 +8437,7 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph UNUSED(backend); } -static ggml_backend_i cuda_backend_i = { +ggml_backend_i cuda_backend_i = { /* .get_name = */ ggml_backend_cuda_name, /* .free = */ ggml_backend_cuda_free, /* .alloc_buffer = */ ggml_backend_cuda_alloc_buffer, @@ -8425,6 +8454,8 @@ static ggml_backend_i cuda_backend_i = { /* .supports_op = */ nullptr, }; +END_ANONYMOUS_NAMESPACE + ggml_backend_t ggml_backend_cuda_init() { ggml_init_cublas(); // TODO: remove from ggml.c