diff --git a/CMakeLists.txt b/CMakeLists.txt index cc7560a7a..ffda74a70 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,6 +75,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" OFF) option(LLAMA_K_QUANTS "llama: use k-quants" ON) +option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) @@ -225,6 +226,14 @@ if (LLAMA_BLAS) endif() endif() +if (LLAMA_K_QUANTS) + set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h) + add_compile_definitions(GGML_USE_K_QUANTS) + if (LLAMA_QKK_64) + add_compile_definitions(GGML_QKK_64) + endif() +endif() + if (LLAMA_CUBLAS) cmake_minimum_required(VERSION 3.17) @@ -289,11 +298,6 @@ if (LLAMA_METAL) ) endif() -if (LLAMA_K_QUANTS) - set(GGML_SOURCES_EXTRA ${GGML_SOURCES_EXTRA} k_quants.c k_quants.h) - add_compile_definitions(GGML_USE_K_QUANTS) -endif() - if (LLAMA_CLBLAST) find_package(CLBlast) if (CLBlast_FOUND) diff --git a/Makefile b/Makefile index 5dd676fad..bda11791d 100644 --- a/Makefile +++ b/Makefile @@ -43,8 +43,11 @@ endif # keep standard at C11 and C++11 # -Ofast tends to produce faster code, but may not be available for some compilers. -#OPT = -Ofast +ifdef LLAMA_FAST +OPT = -Ofast +else OPT = -O3 +endif CFLAGS = -I. $(OPT) -std=c11 -fPIC CXXFLAGS = -I. -I./examples $(OPT) -std=c++11 -fPIC LDFLAGS = @@ -131,6 +134,10 @@ ifndef LLAMA_NO_K_QUANTS CFLAGS += -DGGML_USE_K_QUANTS CXXFLAGS += -DGGML_USE_K_QUANTS OBJS += k_quants.o +ifdef LLAMA_QKK_64 + CFLAGS += -DGGML_QKK_64 + CXXFLAGS += -DGGML_QKK_64 +endif endif ifndef LLAMA_NO_ACCELERATE diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5e2fbc724..c34e96abf 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo //================================= k-quants +#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else #define QK_K 256 +#define K_SCALE_SIZE 12 +#endif typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits @@ -128,13 +134,25 @@ typedef struct { static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); typedef struct { - uint8_t hmask[QK_K/8]; - uint8_t qs[QK_K/4]; // nibbles / quants - uint8_t scales[3*QK_K/64]; - half d; + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits +#ifdef GGML_QKK_64 + uint8_t scales[2]; // scales, quantized with 8 bits +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale } block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding"); +//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding"); +#ifdef GGML_QKK_64 +typedef struct { + half d[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins @@ -142,15 +160,26 @@ typedef struct { uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_K; static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +#endif +#ifdef GGML_QKK_64 typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + half d; // super-block scale + int8_t scales[QK_K/16]; // block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_K; -static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits @@ -349,13 +378,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { const int i = blockIdx.x; + const block_q2_K * x = (const block_q2_K *) vx; + const int tid = threadIdx.x; +#if QK_K == 256 const int n = tid/32; const int l = tid - 32*n; const int is = 8*n + l/16; - const block_q2_K * x = (const block_q2_K *) vx; - const uint8_t q = x[i].qs[32*n + l]; float * y = yy + i*QK_K + 128*n; @@ -365,21 +395,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) { y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); +#else + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const uint8_t q = x[i].qs[il] >> (2*is); + float * y = yy + i*QK_K + 16*is + il; + float dall = x[i].d; + float dmin = x[i].dmin; + y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); +#endif } static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { - int r = threadIdx.x/4; - int i = blockIdx.x; - int tid = r/2; - int is0 = r%2; - int l0 = 16*is0 + 4*(threadIdx.x%4); - int n = tid / 4; - int j = tid - 4*n; - + const int i = blockIdx.x; const block_q3_K * x = (const block_q3_K *) vx; +#if QK_K == 256 + const int r = threadIdx.x/4; + const int tid = r/2; + const int is0 = r%2; + const int l0 = 16*is0 + 4*(threadIdx.x%4); + const int n = tid / 4; + const int j = tid - 4*n; + uint8_t m = 1 << (4*n + j); int is = 8*n + 2*j + is0; int shift = 2*j; @@ -396,9 +437,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { const uint8_t * hm = x[i].hmask; for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); +#else + const int tid = threadIdx.x; + const int is = tid/16; // 0 or 1 + const int il = tid%16; // 0...15 + const int im = il/8; // 0...1 + const int in = il%8; // 0...7 + + float * y = yy + i*QK_K + 16*is + il; + + const uint8_t q = x[i].qs[il] >> (2*is); + const uint8_t h = x[i].hmask[in] >> (2*is + im); + const float d = (float)x[i].d; + + if (is == 0) { + y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } else { + y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); + y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); + } +#endif } +#if QK_K == 256 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { d = q[j] & 63; m = q[j + 4] & 63; @@ -407,19 +470,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); } } +#endif static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { const block_q4_K * x = (const block_q4_K *) vx; const int i = blockIdx.x; - //// assume 64 threads - this is very slightly better than the one below - //const int tid = threadIdx.x; - //const int il = tid/16; - //const int ir = tid%16; - //const int is = 2*il; - //const int n = 2; - +#if QK_K == 256 // assume 32 threads const int tid = threadIdx.x; const int il = tid/8; @@ -443,6 +501,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) { y[l + 0] = d1 * (q[l] & 0xF) - m1; y[l +32] = d2 * (q[l] >> 4) - m2; } +#else + const int tid = threadIdx.x; + const uint8_t * q = x[i].qs; + float * y = yy + i*QK_K; + const float d = (float)x[i].d[0]; + const float m = (float)x[i].d[1]; + y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); + y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4); +#endif } static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { @@ -450,6 +517,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { const int i = blockIdx.x; +#if QK_K == 256 // assume 64 threads - this is very slightly better than the one below const int tid = threadIdx.x; const int il = tid/16; // il is in 0...3 @@ -476,12 +544,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { hm <<= 1; y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + const int tid = threadIdx.x; + const uint8_t q = x[i].qs[tid]; + const int im = tid/8; // 0...3 + const int in = tid%8; // 0...7 + const int is = tid/16; // 0 or 1 + const uint8_t h = x[i].qh[in] >> im; + const float d = x[i].d; + float * y = yy + i*QK_K + tid; + y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); + y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16)); +#endif } static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { const block_q6_K * x = (const block_q6_K *) vx; const int i = blockIdx.x; +#if QK_K == 256 // assume 64 threads - this is very slightly better than the one below const int tid = threadIdx.x; @@ -501,6 +582,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +#else + + // assume 32 threads + const int tid = threadIdx.x; + const int ip = tid/16; // 0 or 1 + const int il = tid - 16*ip; // 0...15 + + float * y = yy + i*QK_K + 16*ip + il; + + const float d = x[i].d; + + const uint8_t ql = x[i].ql[16*ip + il]; + const uint8_t qh = x[i].qh[il] >> (2*ip); + const int8_t * sc = x[i].scales; + + y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32); +#endif } static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { @@ -515,6 +614,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float const block_q2_K * x = (const block_q2_K *)vx + ib0; + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 @@ -528,8 +630,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float const int s_offset = 8*im; const int y_offset = 128*im + l0; - float tmp = 0; // partial sum for thread in warp - uint32_t aux[4]; const uint8_t * d = (const uint8_t *)aux; const uint8_t * m = (const uint8_t *)(aux + 2); @@ -565,6 +665,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float tmp += dall * sum1 - dmin * sum2; } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; + + uint32_t uaux[2]; + const uint8_t * d = (const uint8_t *)uaux; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint32_t * s = (const uint32_t *)x[i].scales; + + uaux[0] = s[0] & 0x0f0f0f0f; + uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; + + const half2 * dh = (const half2 *)&x[i].d; + + const float2 dall = __half22float2(dh[0]); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t ql = q[l]; + sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) + + y[l+16] * d[1] * ((ql >> 2) & 3) + + y[l+32] * d[2] * ((ql >> 4) & 3) + + y[l+48] * d[3] * ((ql >> 6) & 3); + sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; + } + tmp += dall.x * sum1 - dall.y * sum2; + } +#endif // sum up partial sums and write back result __syncthreads(); @@ -573,16 +706,13 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - if (tid == 0) { + if (threadIdx.x == 0) { dst[row] = tmp; } } static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row > nrows) return; @@ -591,6 +721,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float const block_q3_K * x = (const block_q3_K *)vx + ib0; + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 @@ -610,8 +747,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float const uint16_t s_shift = 4*im; - float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const float * y = yy + i * QK_K + y_offset; @@ -640,6 +775,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float tmp += d * sum; } +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 + const int in = offset/8; // 0 or 1 + const int im = offset%8; // 0...7 + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + offset; + const uint8_t * q = x[i].qs + offset; + const uint8_t * s = x[i].scales; + + const float dall = (float)x[i].d; + + float sum = 0; + for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + const uint8_t hl = x[i].hmask[im+l] >> in; + const uint8_t ql = q[l]; + sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) + + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) + + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) + + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4)); + } + tmp += sum; + } +#endif // sum up partial sums and write back result __syncthreads(); @@ -648,22 +811,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - if (tid == 0) { + if (threadIdx.x == 0) { dst[row] = tmp; } } static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row > nrows) return; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; + const block_q4_K * x = (const block_q4_K *)vx + ib0; + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 @@ -683,8 +849,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; - const block_q4_K * x = (const block_q4_K *)vx + ib0; - float tmp = 0; // partial sum for thread in warp for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { @@ -713,6 +877,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + + const int step = tid * K_QUANTS_PER_ITERATION; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const float * y = yy + i*QK_K + step; + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + const float d = (float)x[i].d[0]; + const float m = (float)x[i].d[1]; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); + } + tmp += sum; + } + +#endif // sum up partial sums and write back result __syncthreads(); @@ -728,15 +922,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - //const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x; const int num_blocks_per_row = ncols / QK_K; const int ib0 = row*num_blocks_per_row; + const block_q5_K * x = (const block_q5_K *)vx + ib0; + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = threadIdx.x/2; // 0...15 const int ix = threadIdx.x%2; @@ -757,10 +955,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; - const block_q5_K * x = (const block_q5_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 2) { const uint8_t * ql1 = x[i].qs + q_offset; @@ -793,9 +987,32 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; } tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; - } +#else + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + const int step = tid * K_QUANTS_PER_ITERATION; + const int im = step/8; + const int in = step%8; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const uint8_t * q = x[i].qs + step; + const int8_t * s = x[i].scales; + const float * y = yy + i*QK_K + step; + const float d = x[i].d; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + const uint8_t h = x[i].qh[in+j] >> im; + sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) + + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) + + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16)) + + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16)); + } + tmp += sum; + } +#endif + // sum up partial sums and write back result __syncthreads(); #pragma unroll @@ -803,7 +1020,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); } - if (tid == 0) { + if (threadIdx.x == 0) { dst[row] = tmp; } } @@ -820,6 +1037,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float const block_q6_K * x = (const block_q6_K *)vx + ib0; +#if QK_K == 256 + const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 @@ -874,6 +1093,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float } +#else + + const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 + + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = x[i].ql + step; + const uint8_t * qh = x[i].qh + step; + const int8_t * s = x[i].scales; + + const float d = x[i+0].d; + + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } + tmp += sum; + + } + +#endif + // sum up partial sums and write back result __syncthreads(); #pragma unroll @@ -1252,12 +1502,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); +#else + dequantize_block_q2_K<<>>(vx, y); +#endif } static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); +#else + dequantize_block_q3_K<<>>(vx, y); +#endif } static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { @@ -1267,12 +1525,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); +#else + dequantize_block_q5_K<<>>(vx, y); +#endif } static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; +#if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); +#else + dequantize_block_q6_K<<>>(vx, y); +#endif } static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { diff --git a/ggml-metal.m b/ggml-metal.m index a7e104dc7..7551231b9 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -51,21 +51,21 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); - GGML_METAL_DECL_KERNEL(get_rows_q2_k); - GGML_METAL_DECL_KERNEL(get_rows_q3_k); - GGML_METAL_DECL_KERNEL(get_rows_q4_k); - GGML_METAL_DECL_KERNEL(get_rows_q5_k); - GGML_METAL_DECL_KERNEL(get_rows_q6_k); + GGML_METAL_DECL_KERNEL(get_rows_q2_K); + GGML_METAL_DECL_KERNEL(get_rows_q3_K); + GGML_METAL_DECL_KERNEL(get_rows_q4_K); + GGML_METAL_DECL_KERNEL(get_rows_q5_K); + GGML_METAL_DECL_KERNEL(get_rows_q6_K); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f16); @@ -132,7 +132,13 @@ struct ggml_metal_context * ggml_metal_init(void) { exit(1); } +#ifdef GGML_QKK_64 + MTLCompileOptions* options = [MTLCompileOptions new]; + options.preprocessorMacros = @{ @"QK_K" : @(64) }; + ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; +#else ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error]; +#endif if (error) { fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]); exit(1); @@ -159,21 +165,21 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); - GGML_METAL_ADD_KERNEL(get_rows_q2_k); - GGML_METAL_ADD_KERNEL(get_rows_q3_k); - GGML_METAL_ADD_KERNEL(get_rows_q4_k); - GGML_METAL_ADD_KERNEL(get_rows_q5_k); - GGML_METAL_ADD_KERNEL(get_rows_q6_k); + GGML_METAL_ADD_KERNEL(get_rows_q2_K); + GGML_METAL_ADD_KERNEL(get_rows_q3_K); + GGML_METAL_ADD_KERNEL(get_rows_q4_K); + GGML_METAL_ADD_KERNEL(get_rows_q5_K); + GGML_METAL_ADD_KERNEL(get_rows_q6_K); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); @@ -662,7 +668,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; } break; case GGML_TYPE_Q3_K: { @@ -671,7 +677,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; } break; case GGML_TYPE_Q4_K: { @@ -680,7 +686,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; } break; case GGML_TYPE_Q5_K: { @@ -689,7 +695,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; } break; case GGML_TYPE_Q6_K: { @@ -698,7 +704,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; } break; default: { @@ -750,11 +756,11 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index d1e49222d..e62fe6842 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32( } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { - for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } } @@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32( } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } } @@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32( //============================================ k-quants ====================================================== +#ifndef QK_K #define QK_K 256 +#else +static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64"); +#endif + +#if QK_K == 256 +#define K_SCALE_SIZE 12 +#else +#define K_SCALE_SIZE 4 +#endif typedef struct { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins -} block_q2_k; +} block_q2_K; // 84 bytes / block typedef struct { uint8_t hmask[QK_K/8]; // quants - high bit uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits - half d; // super-block scale -} block_q3_k; -// 110 bytes / block +#if QK_K == 64 + uint8_t scales[2]; +#else + uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits +#endif + half d; // super-block scale +} block_q3_K; +#if QK_K == 64 +typedef struct { + half d[2]; // super-block scales/mins + uint8_t scales[2]; + uint8_t qs[QK_K/2]; // 4-bit quants +} block_q4_K; +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_k; -// 144 bytes / block +} block_q4_K; +#endif +#if QK_K == 64 +typedef struct { + half d; // super-block scales/mins + int8_t scales[QK_K/16]; // 8-bit block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +#else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_k; +} block_q5_K; // 176 bytes / block +#endif typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits uint8_t qh[QK_K/4]; // quants, upper 2 bits int8_t scales[QK_K/16]; // scales, quantized with 8 bits half d; // super-block scale -} block_q6_k; +} block_q6_K; // 210 bytes / block static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { @@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { //========================================== dequantization ============================= -static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) { +static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i device const uint8_t * q = x[i].qs; +#if QK_K == 256 int is = 0; float dl, ml; for (int n = 0; n < QK_K; n += 128) { @@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i } q += 32; } +#else + float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); + float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); + float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); + float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); + for (int l = 0; l < 16; ++l) { + y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1; + y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2; + y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3; + y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4; + } + y += QK_K; +#endif } } -static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) { +static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; +#if QK_K == 256 + const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; @@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i } q += 32; } - } +#else + for (int i = 0; i < nb; i++) { + + const float d_all = (float)(x[i].d); + + device const uint8_t * q = x[i].qs; + device const uint8_t * hm = x[i].hmask; + + const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); + const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); + const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); + const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); + + for (int l = 0; l < 8; ++l) { + uint8_t h = hm[l]; + y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); + y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); + y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); + y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); + y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); + y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); + y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); + y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4)); + } + y += QK_K; + } +#endif } -static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) { +static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; - for (int i = 0; i < nb; i++) { + device const uint8_t * q = x[i].qs; + +#if QK_K == 256 const float d = x[i].d; const float min = x[i].dmin; - device const uint8_t * q = x[i].qs; device const uint8_t * scales = x[i].scales; int is = 0; @@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; q += 32; is += 2; } +#else + device const uint8_t * s = x[i].scales; + device const half2 * dh = (device const half2 *)x[i].d; + const float2 d = (float2)dh[0]; + const float d1 = d[0] * (s[0] & 0xF); + const float d2 = d[0] * (s[1] & 0xF); + const float m1 = d[1] * (s[0] >> 4); + const float m2 = d[1] * (s[1] >> 4); + for (int l = 0; l < 32; ++l) { + y[l+ 0] = d1 * (q[l] & 0xF) - m1; + y[l+32] = d2 * (q[l] >> 4) - m2; + } + y += QK_K; +#endif } } -static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) { +static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; +#if QK_K == 256 for (int i = 0; i < nb; i++) { const float d = (float)(x[i].d); @@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i u1 <<= 2; u2 <<= 2; } } +#else + for (int i = 0; i < nb; i++) { + + const float d = (float)x[i].d; + + device const uint8_t * ql = x[i].qs; + device const uint8_t * qh = x[i].qh; + device const int8_t * sc = x[i].scales; + + for (int l = 0; l < 8; ++l) { + y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16)); + y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16)); + y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16)); + y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16)); + y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16)); + y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16)); + y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16)); + y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16)); + } + y += QK_K; + } +#endif } -static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) { +static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i const float d = x[i].d; +#if QK_K == 256 for (int n = 0; n < QK_K; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; @@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i qh += 32; sc += 8; } +#else + for (int l = 0; l < 16; ++l) { + const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l+ 0] = d * sc[0] * q1; + y[l+16] = d * sc[1] * q2; + y[l+32] = d * sc[2] * q3; + y[l+48] = d * sc[3] * q4; + } + y += 64; +#endif } } -kernel void kernel_get_rows_q2_k( +kernel void kernel_get_rows_q2_K( device const void * src0, device const int * src1, device float * dst, @@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q2_k( - (device const block_q2_k *) ((device char *) src0 + r*nb01), + dequantize_row_q2_K( + (device const block_q2_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q3_k( +kernel void kernel_get_rows_q3_K( device const void * src0, device const int * src1, device float * dst, @@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q3_k( - (device const block_q3_k *) ((device char *) src0 + r*nb01), + dequantize_row_q3_K( + (device const block_q3_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q4_k( +kernel void kernel_get_rows_q4_K( device const void * src0, device const int * src1, device float * dst, @@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q4_k( - (device const block_q4_k *) ((device char *) src0 + r*nb01), + dequantize_row_q4_K( + (device const block_q4_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q5_k( +kernel void kernel_get_rows_q5_K( device const void * src0, device const int * src1, device float * dst, @@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q5_k( - (device const block_q5_k *) ((device char *) src0 + r*nb01), + dequantize_row_q5_K( + (device const block_q5_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q6_k( +kernel void kernel_get_rows_q6_K( device const void * src0, device const int * src1, device float * dst, @@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q6_k( - (device const block_q6_k *) ((device char *) src0 + r*nb01), + dequantize_row_q6_K( + (device const block_q6_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } //====================================== dot products ========================= -kernel void kernel_mul_mat_q2_k_f32( +kernel void kernel_mul_mat_q2_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb; + device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid%4; // 0...3 @@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32( const int y_offset = 64*il + n*ir; const int q_offset = 32*ip + n*ir; - sum[ith] = 0.0f; - - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q = x[i].qs + q_offset; @@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32( device const float * y = yy + i*QK_K + y_offset; - //float4 s = {0.f, 0.f, 0.f, 0.f}; float2 s = {0.f, 0.f}; float smin = 0; for (int l = 0; l < n; ++l) { @@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32( sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin; } +#else + const int il = 4 * tpitg.x; + + uint32_t aux[2]; + thread const uint8_t * d = (thread const uint8_t *)aux; + thread const uint8_t * m = (thread const uint8_t *)aux + 4; + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + device const uint8_t * q = x[i].qs + il; + device const float * y = yy + i*QK_K + il; + + const float dall = (float)x[i].d; + const float dmin = (float)x[i].dmin; + + device const uint32_t * a = (device const uint32_t *)x[i].scales; + aux[0] = a[0] & 0x0f0f0f0f; + aux[1] = (a[0] >> 4) & 0x0f0f0f0f; + + for (int l = 0; l < 4; ++l) { + sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0]) + + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1]) + + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2]) + + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]); + } + } +#endif + sum[ith] = sumf; - //int mask1 = (ith%4 == 0); - //int mask2 = (ith%16 == 0); - - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i]; - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i]; - //threadgroup_barrier(mem_flags::mem_threadgroup); - //if (ith == 0) { - // for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - // dst[r1*ne0 + r0] = sum[0]; - //} - // // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. // threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%4 == 0) { @@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32( } } -kernel void kernel_mul_mat_q3_k_f32( +kernel void kernel_mul_mat_q3_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32( uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - - const uint8_t m3 = 3; - const int8_t m4 = 4; - const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb; + device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; +#if QK_K == 256 + + const uint8_t m3 = 3; + const int8_t m4 = 4; + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + const int tid = tpitg.y; // expecting 16 const int ip = tid/8; // 0 or 1 const int il = tid/2 - 4*ip; // 0...3 @@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32( //sum[ith] = sumf; sum[ith] = sumf1 - 32.f*sumf2; +#else + const int il = 4 * tpitg.x; // 0, 4, 8, 12 + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + float sumf = 0; + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + const float d_all = (float)(x[i].d); + + device const uint8_t * q = x[i].qs + il; + device const uint8_t * h = x[i].hmask + in; + device const float * y = yy + i * QK_K + il; + + const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); + const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); + const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); + const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); + + for (int l = 0; l < 4; ++l) { + const uint8_t hm = h[l] >> im; + sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4)) + + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4)) + + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4)) + + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4)); + } + + } + + sum[ith] = sumf; + +#endif // // Accumulate the sum from all threads in the threadgroup @@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32( } -kernel void kernel_mul_mat_q4_k_f32( +kernel void kernel_mul_mat_q4_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32( uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + + float sumf = 0; + +#if QK_K == 256 + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid - 4*il;// 0...3 @@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32( const int q_offset = 32*im + l0; const int y_offset = 64*im + l0; - sum[ith] = 0.0f; - uchar2 sc1, sc2, sc3, sc4; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q1 = (x + i)->qs + q_offset; @@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32( sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; } +#else + uint16_t aux16[2]; + thread const uint8_t * scales = (thread const uint8_t *)aux16; + + const int il = 4*tpitg.x; + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + device const uint8_t * q = x[i].qs + il; + device const float * y = yy + i * QK_K + il; + + const float d = (float)x[i].d[0]; + const float m = (float)x[i].d[1]; + + device const uint16_t * a = (device const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + for (int l = 0; l < 4; ++l) { + sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16]) + + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]); + } + } +#endif sum[ith] = sumf; @@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32( //} } -kernel void kernel_mul_mat_q5_k_f32( +kernel void kernel_mul_mat_q5_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32( uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb; + device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid - 4*il;// 0...3 @@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32( uchar2 sc1, sc2, sc3, sc4; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * q1 = (x + i)->qs + q_offset; @@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32( sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; } +#else + const int il = 4 * tpitg.x; // 0, 4, 8, 12 + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 + + for (int i = tpitg.y; i < nb; i += tptg.y) { + + const float d = (float)x[i].d; + device const uint8_t * q = x[i].qs + il; + device const uint8_t * h = x[i].qh + in; + device const int8_t * s = x[i].scales; + device const float * y = yy + i*QK_K + il; + + for (int l = 0; l < 4; ++l) { + const uint8_t hl = h[l] >> im; + sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16)) + + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16)) + + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16)) + + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16)); + } + } +#endif sum[ith] = sumf; // @@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32( } -kernel void kernel_mul_mat_q6_k_f32( +kernel void kernel_mul_mat_q6_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb; + device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf = 0; + +#if QK_K == 256 // Note: we absolutely assume that tptg.y = 16 and QK_K = 256! const int iqs = 16 * tpitg.y; const int ip = iqs / 128; // 0 or 1 @@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32( const int q_offset_l = 64*ip + l0; const int q_offset_h = 32*ip + l0; - float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { device const uint8_t * ql = x[i].ql + q_offset_l; @@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32( sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); } +#else + const int il = 4*tpitg.x; // 0, 4, 8, 12 + + for (int i = tpitg.y; i < nb; i += tptg.y) { + device const float * y = yy + i * QK_K + il; + device const uint8_t * ql = x[i].ql + il; + device const uint8_t * qh = x[i].qh + il; + device const int8_t * s = x[i].scales; + + const float d = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); + sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); + } + +#endif sum[ith] = sumf; diff --git a/k_quants.c b/k_quants.c index a48c82171..46dd884b0 100644 --- a/k_quants.c +++ b/k_quants.c @@ -261,6 +261,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t return scale; } +#if QK_K == 256 static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { if (j < 4) { *d = q[j] & 63; *m = q[j + 4] & 63; @@ -269,6 +270,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); } } +#endif //========================- 2-bit (de)-quantization @@ -330,11 +332,17 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict } } +#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); } } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif x += QK_K; @@ -352,6 +360,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int const uint8_t * q = x[i].qs; +#if QK_K == 256 int is = 0; float dl, ml; for (int n = 0; n < QK_K; n += 128) { @@ -370,7 +379,19 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int } q += 32; } - +#else + float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); + float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); + float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); + float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); + for (int l = 0; l < 16; ++l) { + y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1; + y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2; + y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3; + y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4; + } + y += QK_K; +#endif } } @@ -412,6 +433,7 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict } } +#if QK_K == 256 memset(y[i].scales, 0, 12); if (max_scale) { float iscale = -32.f/max_scale; @@ -445,9 +467,39 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict L[16*j + ii] = l + 4; } } +#else + if (max_scale) { + float iscale = -8.f/max_scale; + for (int j = 0; j < QK_K/16; j+=2) { + int l1 = nearest_int(iscale*scales[j]); + l1 = 8 + MAX(-8, MIN(7, l1)); + int l2 = nearest_int(iscale*scales[j+1]); + l2 = 8 + MAX(-8, MIN(7, l2)); + y[i].scales[j/2] = l1 | (l2 << 4); + } + y[i].d = ggml_fp32_to_fp16(1/iscale); + } else { + for (int j = 0; j < QK_K/16; j+=2) { + y[i].scales[j/2] = 0; + } + y[i].d = ggml_fp32_to_fp16(0.f); + } + for (int j = 0; j < QK_K/16; ++j) { + int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4; + float d = ggml_fp16_to_fp32(y[i].d) * (s - 8); + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } +#endif memset(y[i].hmask, 0, QK_K/8); - // We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc. + // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. int m = 0; uint8_t hm = 1; for (int j = 0; j < QK_K; ++j) { @@ -459,19 +511,25 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict m = 0; hm <<= 1; } } +#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); } } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif x += QK_K; } } +#if QK_K == 256 void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { assert(k % QK_K == 0); - assert(QK_K == 256); const int nb = k / QK_K; const uint32_t kmask1 = 0x03030303; @@ -519,6 +577,39 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int } } +#else +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + assert(QK_K == 64); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d_all = ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + + const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); + const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); + const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); + const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); + + for (int l=0; l<8; ++l) { + uint8_t h = hm[l]; + y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); + y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); + y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); + y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); + y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); + y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); + y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); + y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4)); + } + y += QK_K; + } +} +#endif void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { quantize_row_q3_K_reference(x, vy, k); @@ -563,6 +654,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict } } +#if QK_K == 256 float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; float inv_min = max_min > 0 ? 63.f/max_min : 0.f; for (int j = 0; j < QK_K/32; ++j) { @@ -594,9 +686,43 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict L[32*j + ii] = l; } } +#else + const float s_factor = 15.f; + float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f; + float inv_min = max_min > 0 ? s_factor/max_min : 0.f; + int d1 = nearest_int(inv_scale*scales[0]); + int m1 = nearest_int(inv_min*mins[0]); + int d2 = nearest_int(inv_scale*scales[1]); + int m2 = nearest_int(inv_min*mins[1]); + y[i].scales[0] = d1 | (m1 << 4); + y[i].scales[1] = d2 | (m2 << 4); + y[i].d[0] = ggml_fp32_to_fp16(max_scale/s_factor); + y[i].d[1] = ggml_fp32_to_fp16(max_min/s_factor); + + float sumlx = 0; + int suml2 = 0; + for (int j = 0; j < QK_K/32; ++j) { + const uint8_t sd = y[i].scales[j] & 0xF; + const uint8_t sm = y[i].scales[j] >> 4; + const float d = ggml_fp16_to_fp32(y[i].d[0]) * sd; + if (!d) continue; + const float m = ggml_fp16_to_fp32(y[i].d[1]) * sm; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + m)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + sumlx += (x[32*j + ii] + m)*l*sd; + suml2 += l*l*sd*sd; + } + } + if (suml2) { + y[i].d[0] = ggml_fp32_to_fp16(sumlx/suml2); + } +#endif uint8_t * q = y[i].qs; for (int j = 0; j < QK_K; j += 64) { - for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4); + for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); + q += 32; } x += QK_K; @@ -610,11 +736,13 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int for (int i = 0; i < nb; i++) { - const float d = ggml_fp16_to_fp32(x[i].d); - const float min = ggml_fp16_to_fp32(x[i].dmin); - const uint8_t * q = x[i].qs; +#if QK_K == 256 + + const float d = ggml_fp16_to_fp32(x[i].d); + const float min = ggml_fp16_to_fp32(x[i].dmin); + int is = 0; uint8_t sc, m; for (int j = 0; j < QK_K; j += 64) { @@ -626,6 +754,17 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; q += 32; is += 2; } +#else + const float dall = ggml_fp16_to_fp32(x[i].d[0]); + const float mall = ggml_fp16_to_fp32(x[i].d[1]); + const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4); + const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4); + for (int l = 0; l < 32; ++l) { + y[l+ 0] = d1 * (q[l] & 0xF) - m1; + y[l+32] = d2 * (q[l] >> 4) - m2; + } + y += QK_K; +#endif } } @@ -653,12 +792,19 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict assert(k % QK_K == 0); const int nb = k / QK_K; +#if QK_K == 256 uint8_t L[QK_K]; float mins[QK_K/32]; float scales[QK_K/32]; +#else + int8_t L[QK_K]; + float scales[QK_K/16]; +#endif for (int i = 0; i < nb; i++) { +#if QK_K == 256 + float max_scale = 0; // as we are deducting the min, scales are always positive float max_min = 0; for (int j = 0; j < QK_K/32; ++j) { @@ -725,6 +871,52 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict m1 <<= 2; m2 <<= 2; ql += 32; } +#else + float max_scale = 0, amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1); + float abs_scale = fabsf(scales[j]); + if (abs_scale > amax) { + amax = abs_scale; + max_scale = scales[j]; + } + } + + float iscale = -128.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*scales[j]); + y[i].scales[j] = MAX(-128, MIN(127, l)); + } + y[i].d = ggml_fp32_to_fp16(1/iscale); + + for (int j = 0; j < QK_K/16; ++j) { + const float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j]; + if (!d) continue; + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-16, MIN(15, l)); + L[16*j + ii] = l + 16; + } + } + + uint8_t * restrict qh = y[i].qh; + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + for (int j = 0; j < 32; ++j) { + int jm = j%8; + int is = j/8; + int l1 = L[j]; + if (l1 > 15) { + l1 -= 16; qh[jm] |= (1 << is); + } + int l2 = L[j + 32]; + if (l2 > 15) { + l2 -= 16; qh[jm] |= (1 << (4 + is)); + } + ql[j] = l1 | (l2 << 4); + } +#endif x += QK_K; @@ -737,12 +929,14 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int for (int i = 0; i < nb; i++) { - const float d = ggml_fp16_to_fp32(x[i].d); - const float min = ggml_fp16_to_fp32(x[i].dmin); - const uint8_t * ql = x[i].qs; const uint8_t * qh = x[i].qh; +#if QK_K == 256 + + const float d = ggml_fp16_to_fp32(x[i].d); + const float min = ggml_fp16_to_fp32(x[i].dmin); + int is = 0; uint8_t sc, m; uint8_t u1 = 1, u2 = 2; @@ -756,6 +950,21 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int ql += 32; is += 2; u1 <<= 2; u2 <<= 2; } +#else + float d = ggml_fp16_to_fp32(x[i].d); + const int8_t * restrict s = x[i].scales; + for (int l = 0; l < 8; ++l) { + y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16)); + y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16)); + y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16)); + y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16)); + y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16)); + y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16)); + y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16)); + y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16)); + } + y += QK_K; +#endif } } @@ -823,6 +1032,7 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict uint8_t * restrict ql = y[i].ql; uint8_t * restrict qh = y[i].qh; +#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { const uint8_t q1 = L[j + l + 0] & 0xF; @@ -836,6 +1046,16 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict ql += 64; qh += 32; } +#else + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[l + 0] & 0xF; + const uint8_t q2 = L[l + 32] & 0xF; + ql[l] = q1 | (q2 << 4); + } + for (int l = 0; l < 16; ++l) { + qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6); + } +#endif x += QK_K; @@ -854,6 +1074,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int const uint8_t * restrict qh = x[i].qh; const int8_t * restrict sc = x[i].scales; +#if QK_K == 256 for (int n = 0; n < QK_K; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; @@ -871,6 +1092,19 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int qh += 32; sc += 8; } +#else + for (int l = 0; l < 16; ++l) { + const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l+ 0] = d * sc[0] * q1; + y[l+16] = d * sc[1] * q2; + y[l+32] = d * sc[2] * q3; + y[l+48] = d * sc[3] * q4; + } + y += 64; +#endif } } @@ -1002,6 +1236,7 @@ static inline __m128i get_scale_shuffle(int i) { } #endif +#if QK_K == 256 void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const block_q2_K * restrict x = vx; @@ -1201,6 +1436,168 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri #endif } +#else + +void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + + const block_q2_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m3 = vdupq_n_u8(0x3); + const int32x4_t vzero = vdupq_n_s32(0); + + int8x16x4_t q2bytes; + + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const float dmin = -y[i].d * (float)x[i].dmin; + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + + sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + + int isum1 = 0, isum2 = 0; + + const uint8x16_t q2bits = vld1q_u8(q2); + + const int8x16x4_t q8bytes = vld1q_s8_x4(q8); + + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); + q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); + q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); + +#if defined(__ARM_FEATURE_DOTPROD) + isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; + isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; + isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; + isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; +#else + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum1 += vaddvq_s16(p1) * scales[0]; + isum2 += vaddvq_s16(p2) * scales[1]; + + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum1 += vaddvq_s16(p3) * scales[2]; + isum2 += vaddvq_s16(p4) * scales[3]; +#endif + sum += d * (isum1 + isum2); + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float summs = 0; + + // TODO: optimize this + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; + + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; + + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m256i q2_0 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 2), q2bits), m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + + const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); + const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); + const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); + const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); + } + + *s = hsum_float_8(acc) + summs; + +#else + + float sumf = 0; + + int isum[4]; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < QK_K/16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin); + + isum[0] = isum[1] = isum[2] = isum[3] = 0; + for (int l = 0; l < 16; ++l) { + isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); + isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); + isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); + isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); + } + for (int l = 0; l < 4; ++l) { + isum[l] *= (sc[l] & 0xF); + } + sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; + } + *s = sumf; +#endif +} +#endif + +#if QK_K == 256 void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -1501,6 +1898,206 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri } +#else + +void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t m3b = vdupq_n_u8(0x3); + const uint8x16_t mh = vdupq_n_u8(4); + + int8x16x4_t q3bytes; + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + uint8x16x4_t q3h; + + const uint8x8_t hbits = vld1_u8(x[i].hmask); + const uint8x16_t q3bits = vld1q_u8(x[i].qs); + const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs); + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * (float)x[i].d; + + const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); + q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); + q3h.val[1] = vandq_u8(mh, htmp); + q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); + q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); + + q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); + q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); + q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); + q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; +#else + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3]; +#endif + + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m1 = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m256i scale_0 = _mm256_set_m128i(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8)); + const __m256i scale_1 = _mm256_set_m128i(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8)); + + memcpy(&aux64, x[i].hmask, 8); + + const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m256i q3h_0 = _mm256_set_m128i(_mm_srli_epi16(haux, 2), haux); + __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4); + q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); + q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m256i q3aux = _mm256_set_m128i(_mm_srli_epi16(q3bits, 2), q3bits); + const __m256i q3l_0 = _mm256_and_si256(q3aux, m3); + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + // multiply with scales + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + p16_0 = _mm256_add_epi32(p16_0, p16_1); + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); + + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + int32_t scales[4]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 8; ++l) { + a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); + a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); + a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); + a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); + a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); + a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); + a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); + a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4); + } + + scales[0] = (x[i].scales[0] & 0xF) - 8; + scales[1] = (x[i].scales[0] >> 4) - 8; + scales[2] = (x[i].scales[1] & 0xF) - 8; + scales[3] = (x[i].scales[1] >> 4) - 8; + + memset(aux32, 0, 8*sizeof(int32_t)); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l]; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} +#endif + +#if QK_K == 256 void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -1614,9 +2211,6 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; @@ -1624,6 +2218,9 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri utmp[2] = uaux; utmp[0] &= kmask1; + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); @@ -1726,7 +2323,176 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri *s = sumf; #endif } +#else +void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + float sumf = 0; + + int8x16x2_t q4bytes; + int8x16x4_t q8bytes; + + float sum_mins = 0.f; + + uint16_t aux16[2]; + const uint8_t * restrict scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); + sum_mins += y[i].d * (float)x[i].d[1] * summi; + + const float d = y[i].d * (float)x[i].d[0]; + + const uint8x16x2_t q4bits = vld1q_u8_x2(q4); + +#ifdef __ARM_FEATURE_DOTPROD + q8bytes = vld1q_s8_x4(q8); + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); + const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; + +#else + q8bytes = vld1q_s8_x4(q8); + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0]; + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3]))); + int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1]; + +#endif + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf - sum_mins; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = ggml_fp16_to_fp32(x[i].d[0]) * y[i].d; + const float m = ggml_fp16_to_fp32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + + const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc); + + const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#else + + uint8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; + for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * ggml_fp16_to_fp32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d[0]); + + for (int j = 0; j < QK_K/32; ++j) { + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; + q8 += 16; a += 16; + const float dl = d * scales[j]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]); + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + +#if QK_K == 256 void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -1840,18 +2606,23 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri for (int i = 0; i < nb; ++i) { - const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); - const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); - const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; +#if QK_K == 256 + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); + memcpy(utmp, x[i].scales, 12); utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); const uint32_t uaux = utmp[1] & kmask1; utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; +#else + // TODO + const float d = 0, dmin = 0; +#endif const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); @@ -1972,8 +2743,169 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri #endif } +#else + +void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int32x4_t mzero = vdupq_n_s32(0); + const uint8x16_t mh = vdupq_n_u8(16); + + int8x16x4_t q5bytes; + uint8x16x4_t q5h; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const int8_t * sc = x[i].scales; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const uint8x8_t qhbits = vld1_u8(qh); + + const uint8x16x2_t q5bits = vld1q_u8_x2(q5); + const int8x16x4_t q8bytes = vld1q_s8_x4(q8); + + const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); + q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); + q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); + q5h.val[2] = vbicq_u8(mh, htmp); + q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); + + q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); + q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); + q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); + q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); + int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); + int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); + int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); + + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + +#else + + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1); + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3); + + sumf += d*sumi; +#endif + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); + + const __m256i scale_l = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); + const __m256i scale_h = _mm256_set_m128i(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); + + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m256i haux256 = _mm256_set_m128i(_mm_srli_epi16(haux128, 2), haux128); + + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4); + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0)); + const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1)); + const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0)); + const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1)); + + const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc); + + } + + *s = hsum_float_8(acc); + +#else + uint8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) { + a[l+ 0] = q4[l] & 0xF; + a[l+32] = q4[l] >> 4; + } + for (int is = 0; is < 8; ++is) { + uint8_t m = 1 << is; + for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16); + } + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + const int8_t * restrict sc = x[i].scales; + + for (int j = 0; j < QK_K/16; ++j) { + const float dl = d * sc[j]; + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]); + q8 += 16; a += 16; + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + + +#if QK_K == 256 void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -2242,3 +3174,179 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri *s = sumf; #endif } + +#else + +void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int32x4_t vzero = vdupq_n_s32(0); + const int8x16_t m32s = vdupq_n_s8(32); + + const uint8x16_t mone = vdupq_n_u8(3); + + int8x16x4_t q6bytes; + uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = (float)x[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int32_t isum = 0; + + uint8x16_t qhbits = vld1q_u8(qh); + uint8x16x2_t q6bits = vld1q_u8_x2(q6); + int8x16x4_t q8bytes = vld1q_s8_x4(q8); + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits, 2); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 4); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); + q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; +#else + + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + + sum += isum * d_all * y[i].d; + + } + *s = sum; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + + __m256i sumi = _mm256_setzero_si256(); + + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_set_m128i(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int l = 0; l < 16; ++l) { + a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#endif diff --git a/k_quants.h b/k_quants.h index 10a0baac7..6abe3d7b8 100644 --- a/k_quants.h +++ b/k_quants.h @@ -7,7 +7,13 @@ #include // Super-block size +#ifdef GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else #define QK_K 256 +#define K_SCALE_SIZE 12 +#endif // // Super-block quantization structures @@ -29,38 +35,67 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w // weight is represented as x = a * q // 16 blocks of 16 elemenets each // Effectively 3.4375 bits per weight +#ifdef GGML_QKK_64 typedef struct { uint8_t hmask[QK_K/8]; // quants - high bit uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + uint8_t scales[2]; ggml_fp16_t d; // super-block scale } block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding"); +static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); +#else +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[12]; // scales, quantized with 6 bits + ggml_fp16_t d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); +#endif // 4-bit quantization // 16 blocks of 32 elements each // weight is represented as x = a * q + b // Effectively 4.5 bits per weight +#ifdef GGML_QKK_64 +typedef struct { + ggml_fp16_t d[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else typedef struct { ggml_fp16_t d; // super-block scale for quantized scales ggml_fp16_t dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding"); +static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); +#endif // 5-bit quantization // 16 blocks of 32 elements each // weight is represented as x = a * q + b // Effectively 5.5 bits per weight +#ifdef GGML_QKK_64 typedef struct { - ggml_fp16_t d; // super-block scale for quantized scales - ggml_fp16_t dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + ggml_fp16_t d; // super-block scale + int8_t scales[QK_K/16]; // 8-bit block scales uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_K; -static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + ggml_fp16_t d; // super-block scale for quantized scales + ggml_fp16_t dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif // 6-bit quantization // weight is represented as x = a * q diff --git a/llama.cpp b/llama.cpp index ac22a48f8..c41c2a8a3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -21,9 +21,13 @@ #endif #ifdef GGML_USE_K_QUANTS #ifndef QK_K +#ifdef GGML_QKK_64 +#define QK_K 64 +#else #define QK_K 256 #endif #endif +#endif #include #include @@ -2470,6 +2474,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s std::vector workers; std::mutex mutex; + auto use_more_bits = [] (int i_layer, int num_layers) -> bool { + return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2; + }; + size_t idx = 0; for (llama_load_tensor & tensor : model_loader->tensors_map.tensors) { llama_buffer read_data; @@ -2524,15 +2532,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && - (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8 || - (i_attention_wv - n_attention_wv/8)%3 == 2)) new_type = GGML_TYPE_Q6_K; + use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K; + else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) && + (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; ++i_attention_wv; } else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && - (i_feed_forward_w2 < n_feed_forward_w2/8 || i_feed_forward_w2 >= 7*n_feed_forward_w2/8 || - (i_feed_forward_w2 - n_feed_forward_w2/8)%3 == 2)) new_type = GGML_TYPE_Q6_K; + use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; + //else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K; ++i_feed_forward_w2; } else if (tensor.name.find("attention.wo.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;