diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 920466aae..4c9e21429 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -13,6 +13,8 @@ #include "ggml-cuda.h" #include "ggml.h" +#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products + #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif @@ -74,7 +76,7 @@ typedef void (*ggml_cuda_op_t)( #define QK4_0 32 #define QR4_0 2 -#define QI4_0 4 +#define QI4_0 (QK4_0 / (4 * QR4_0)) typedef struct { half d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants @@ -83,7 +85,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 #define QK4_1 32 #define QR4_1 2 -#define QI4_1 4 +#define QI4_1 (QK4_1 / (4 * QR4_1)) typedef struct { half d; // delta half m; // min @@ -93,7 +95,7 @@ static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong #define QK5_0 32 #define QR5_0 2 -#define QI5_0 4 +#define QI5_0 (QK5_0 / (4 * QR5_0)) typedef struct { half d; // delta uint8_t qh[4]; // 5-th bit of quants @@ -103,7 +105,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5 #define QK5_1 32 #define QR5_1 2 -#define QI5_1 4 +#define QI5_1 (QK5_1 / (4 * QR5_1)) typedef struct { half d; // delta half m; // min @@ -114,7 +116,7 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + #define QK8_0 32 #define QR8_0 1 -#define QI8_0 8 +#define QI8_0 (QK8_0 / (4 * QR8_0)) typedef struct { half d; // delta int8_t qs[QK8_0]; // quants @@ -123,7 +125,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo #define QK8_1 32 #define QR8_1 1 -#define QI8_1 8 +#define QI8_1 (QK8_1 / (4 * QR8_1)) typedef struct { half d; // delta half s; // unquantized sum @@ -143,6 +145,8 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_ #define K_SCALE_SIZE 12 #endif +#define QR2_K 4 +#define QI2_K (QK_K / (4*QR2_K)) typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants @@ -151,6 +155,8 @@ typedef struct { } block_q2_K; static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); +#define QR3_K 4 +#define QI3_K (QK_K / (4*QR3_K)) typedef struct { uint8_t hmask[QK_K/8]; // quants - high bit uint8_t qs[QK_K/4]; // quants - low 2 bits @@ -163,6 +169,8 @@ typedef struct { } block_q3_K; //static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); +#define QR4_K 2 +#define QI4_K (QK_K / (4*QR4_K)) #ifdef GGML_QKK_64 typedef struct { half d[2]; // super-block scales/mins @@ -180,6 +188,8 @@ typedef struct { static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); #endif +#define QR5_K 2 +#define QI5_K (QK_K / (4*QR5_K)) #ifdef GGML_QKK_64 typedef struct { half d; // super-block scale @@ -199,6 +209,8 @@ typedef struct { static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); #endif +#define QR6_K 2 +#define QI6_K (QK_K / (4*QR6_K)) typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits uint8_t qh[QK_K/4]; // quants, upper 2 bits @@ -1271,8 +1283,9 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __ y[iybs + iqs + y_offset] = v.y; } -static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics +static __device__ __forceinline__ float vec_dot_q4_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; int vi; @@ -1293,11 +1306,12 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restric return sumi*d; #else return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= 610 +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics +static __device__ __forceinline__ float vec_dot_q4_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]); @@ -1318,11 +1332,12 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restric return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block #else return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= 610 +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics +static __device__ __forceinline__ float vec_dot_q5_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; int qs; @@ -1353,11 +1368,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restric return sumi*d; #else return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= 610 +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics +static __device__ __forceinline__ float vec_dot_q5_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]); @@ -1387,11 +1403,12 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restric return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block #else return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= 610 +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { -#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics +static __device__ __forceinline__ float vec_dot_q8_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; int vi; @@ -1406,7 +1423,220 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restric return sumi*d; #else return 0.0f; // only to satisfy the compiler -#endif // __CUDA_ARCH__ >= 610 +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q2_K * bq2_K = (const block_q2_K *) vbq; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + const float d = bq2_K->d; + const float dmin = bq2_K->dmin; + + const int v = *((int *) &bq2_K->qs[sizeof(int) * iqs]); + + for (int i = 0; i < QR2_K; ++i) { + const int sc = bq2_K->scales[scale_offset + 2*i]; + + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + const float d8i = bq8i->d; + + const int vi = (v >> (2*i)) & 0x03030303; + const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); + + sumf_d += d8i * (__dp4a(vi, ui, 0) * (sc & 0xF)); // SIMD dot product + sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * (sc >> 4)); // multiply constant q2_K part with sum of q8_1 values + } + + return d*sumf_d - dmin*sumf_m; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q3_K * bq3_K = (const block_q3_K *) vbq; + + const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + float sumf = 0.0f; + + const float d = bq3_K->d; + + int vl; + memcpy(&vl, &bq3_K->qs[sizeof(int) * iqs], sizeof(int)); + + int vh; + memcpy(&vh, &bq3_K->hmask[sizeof(int) * (iqs % (QI3_K/2))], sizeof(int)); + vh = ~vh; // invert the mask so that a 0/1 results in 4/0 being subtracted + vh >>= bq8_offset; + + for (int i = 0; i < QR3_K; ++i) { + const int isc = scale_offset + 2*i; + + const int isc_low = isc % (QK_K/32); + const int sc_shift_low = 4 * (isc / (QK_K/32)); + const int sc_low = (bq3_K->scales[isc_low] >> sc_shift_low) & 0xF; + + const int isc_high = isc % (QK_K/64); + const int sc_shift_high = 2 * (isc / (QK_K/64)); + const int sc_high = ((bq3_K->scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + + const int sc = (sc_low | sc_high) - 32; + + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); + const float d8i = bq8i->d; + + const int vil = (vl >> (2*i)) & 0x03030303; + + const int vih = ((vh >> i) << 2) & 0x04040404; + + const int vi = __vsubss4(vil, vih); + + sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product + } + + return d*sumf; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q4_K * bq4_K = (const block_q4_K *) vbq; + + const int bq8_offset = QR4_K * (iqs / QI8_1); + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + const float d = bq4_K->d; + const float dmin = bq4_K->dmin; + + const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]); + + for (int i = 0; i < QR4_K; ++i) { + const int isc = bq8_offset + i; + + uint8_t sc, m; + get_scale_min_k4(isc, bq4_K->scales, sc, m); + + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); + const float d8i = bq8i->d; + + const int vi = (v >> (4*i)) & 0x0F0F0F0F; + + sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product + sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q4_K with sum of q8_1 values + } + + return d*sumf_d - dmin*sumf_m; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q5_K * bq5_K = (const block_q5_K *) vbq; + + const int bq8_offset = QR5_K * (iqs / QI8_1); + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + + const float d = bq5_K->d; + const float dmin = bq5_K->dmin; + + const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]); + + const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset; + + for (int i = 0; i < QR5_K; ++i) { + const int isc = bq8_offset + i; + + uint8_t sc, m; + get_scale_min_k4(isc, bq5_K->scales, sc, m); + + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); + const float d8i = bq8i->d; + + const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + + const int vih = ((vh >> i) << 4) & 0x10101010; + + const int vi = vil | vih; + + sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product + sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q5_K with sum of q8_1 values + } + + return d*sumf_d - dmin*sumf_m; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) { + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_q6_K * bq6_K = (const block_q6_K *) vbq; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); + + float sumf = 0.0f; + + const float d = bq6_K->d; + + int vl; + memcpy(&vl, &bq6_K->ql[sizeof(int) * iqs], sizeof(int)); + + int vh; + memcpy(&vh, &bq6_K->qh[sizeof(int) * ((QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4))], sizeof(int)); + + for (int i = 0; i < QR6_K; ++i) { + const int sc = bq6_K->scales[scale_offset + 4*i]; + + const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2*i; + const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % (QI8_1))]); + const float d8i = bq8i->d; + + const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (vh_shift + 4*i)) << 4) & 0x30303030; + + const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 + + sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product + } + + return d*sumf; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template @@ -1429,7 +1659,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * for (int i = 0; i < blocks_per_row; i += blocks_per_warp) { const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index - const int iby = i + threadIdx.x / qi; // y block index + const int iby = (i + threadIdx.x / qi) * qk/QK8_1; // y block index that aligns with ibx const int iqs = threadIdx.x % qi; // x block quant index when casting the quants to int @@ -1962,7 +2192,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f } static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); @@ -1971,7 +2201,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * } static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(ncols % QK4_1 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); @@ -1980,7 +2210,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * } static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(ncols % QK5_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); @@ -1989,7 +2219,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * } static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(ncols % QK5_1 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); @@ -1998,7 +2228,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * } static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + GGML_ASSERT(ncols % QK8_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); @@ -2006,6 +2236,51 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } +static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<1, 1, convert_f16><<>>(vx, y, k); @@ -2494,13 +2769,22 @@ inline void ggml_cuda_op_mul_mat_vec( int id; CUDA_CHECK(cudaGetDevice(&id)); - const bool mul_mat_vec_q_implemented = src0->type == GGML_TYPE_Q4_0 || + bool mul_mat_vec_q_implemented = + src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0; +#if QK_K == 256 + mul_mat_vec_q_implemented = mul_mat_vec_q_implemented || + src0->type == GGML_TYPE_Q2_K || + src0->type == GGML_TYPE_Q3_K || + src0->type == GGML_TYPE_Q4_K || + src0->type == GGML_TYPE_Q5_K || + src0->type == GGML_TYPE_Q6_K; +#endif // QK_K == 256 - const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 610 && mul_mat_vec_q_implemented; + const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented; #endif if (use_mul_mat_vec_q) { @@ -2526,6 +2810,21 @@ inline void ggml_cuda_op_mul_mat_vec( case GGML_TYPE_Q8_0: mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + break; default: GGML_ASSERT(false); break;