From ce411498f6c9b3209e8d1a43c1f975e9a15834e7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 21 Feb 2024 16:19:39 +0200 Subject: [PATCH] sync : llama.cpp (ggml/0) ggml-ci --- examples/common-ggml.cpp | 2 + ggml-cuda.cu | 98 +++++++++++++++- ggml-metal.m | 35 ++++++ ggml-metal.metal | 215 ++++++++++++++++++++++++++++++++++- ggml-quants.c | 234 ++++++++++++++++++++++++++++++++++++++- ggml-quants.h | 13 +++ ggml.c | 30 +++++ ggml.h | 2 + 8 files changed, 626 insertions(+), 3 deletions(-) diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index 32c97e8..f1c7f6e 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -66,6 +66,7 @@ bool ggml_common_quantize_0( case GGML_FTYPE_MOSTLY_IQ2_XS: case GGML_FTYPE_MOSTLY_IQ3_XXS: case GGML_FTYPE_MOSTLY_IQ1_S: + case GGML_FTYPE_MOSTLY_IQ4_NL: { fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; @@ -199,6 +200,7 @@ bool ggml_common_quantize_0( case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6caae56..e7c211d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -528,6 +528,15 @@ typedef struct { } block_iq1_s; static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); +#define QK4_NL 32 +#define QR4_NL 2 +#define QI4_NL (QK4_NL / (4*QR4_NL)) +typedef struct { + half d; + uint8_t qs[QK4_NL/2]; +} block_iq4_nl; +static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding"); + #define WARP_SIZE 32 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -1987,6 +1996,26 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_ } +static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + +template +static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int i = blockIdx.x; + const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL); + + const int tid = threadIdx.x; + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[ib].qs + 4*il; + const float d = (float)x[ib].d; + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf]; + y[j+16] = d * kvalues_iq4nl[q4[j] >> 4]; + } + +} static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { @@ -4732,6 +4761,56 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( #endif } +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values, + int & val1, int & val2) { + + uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32; + aux32 = q4 & 0x0f0f0f0f; + uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8); + uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8); + val1 = v1 | (v2 << 16); + aux32 = (q4 >> 4) & 0x0f0f0f0f; + v1 = values[q8[0]] | (values[q8[1]] << 8); + v2 = values[q8[2]] | (values[q8[3]] << 8); + val2 = v1 | (v2 << 16); +} +#endif + +static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + + const block_iq4_nl * bq = (const block_iq4_nl *) vbq; + +#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs; + const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs; + + const uint8_t * values = (const uint8_t *)kvalues_iq4nl; + + int v1, v2; + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) { + const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16); + get_int_from_table_16(aux, values, v1, v2); + sumi1 = __dp4a(v1, q8[l+0], sumi1); + sumi2 = __dp4a(v2, q8[l+4], sumi2); + } + +#else + const uint8_t * q4 = bq->qs + 4*iqs; + const int8_t * q8 = bq8_1->qs + 4*iqs; + + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 4*VDR_Q4_0_Q8_1_MMVQ; ++l) { + sumi1 += q8[l+ 0] * kvalues_iq4nl[q4[l] & 0xf]; + sumi2 += q8[l+16] * kvalues_iq4nl[q4[l] >> 4]; + } +#endif + const float d = (float)bq->d * __low2float(bq8_1->ds); + return d * (sumi1 + sumi2); +} + template static __device__ __forceinline__ void mul_mat_q( @@ -6777,6 +6856,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c dequantize_block_iq1_s<<>>(vx, y); } +template +static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_nl<<>>(vx, y); +} + template static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; @@ -6818,6 +6903,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq3_xxs_cuda; case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_cuda; + case GGML_TYPE_IQ4_NL: + return dequantize_row_iq4_nl_cuda; case GGML_TYPE_F32: return convert_unary_cuda; default: @@ -6855,6 +6942,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_xxs_cuda; case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_cuda; + case GGML_TYPE_IQ4_NL: + return dequantize_row_iq4_nl_cuda; case GGML_TYPE_F16: return convert_unary_cuda; default: @@ -8599,6 +8688,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array= CC_RDNA2 ? 128 : 64; default: GGML_ASSERT(false); @@ -8623,6 +8713,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array= CC_VOLTA ? 128 : 64; case GGML_TYPE_Q6_K: return 64; @@ -8724,6 +8815,10 @@ static void ggml_cuda_op_mul_mat_vec_q( mul_mat_vec_q_cuda (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_q_cuda + (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); + break; default: GGML_ASSERT(false); break; @@ -11446,7 +11541,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return false; } ggml_type a_type = a->type; - if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ1_S) { + if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || + a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL) { if (b->ne[1] == 1 && ggml_nrows(b) > 1) { return false; } diff --git a/ggml-metal.m b/ggml-metal.m index 956e323..0d4aa43 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -62,6 +62,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, @@ -85,6 +86,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, @@ -104,6 +106,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, @@ -120,6 +123,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, @@ -136,6 +140,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_ROPE_F32, GGML_METAL_KERNEL_TYPE_ROPE_F16, GGML_METAL_KERNEL_TYPE_ALIBI_F32, @@ -448,6 +453,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); @@ -471,6 +477,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); @@ -490,6 +497,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); @@ -506,6 +514,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); @@ -522,6 +531,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true); @@ -1338,6 +1348,7 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } @@ -1478,6 +1489,12 @@ static bool ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; } break; + case GGML_TYPE_IQ4_NL: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); @@ -1525,6 +1542,11 @@ static bool ggml_metal_graph_compute( [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src0t == GGML_TYPE_IQ4_NL) { + const int mem_size = 32*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -1619,6 +1641,7 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); } @@ -1762,6 +1785,12 @@ static bool ggml_metal_graph_compute( nth1 = 16; pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; } break; + case GGML_TYPE_IQ4_NL: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; + } break; default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); @@ -1825,6 +1854,11 @@ static bool ggml_metal_graph_compute( [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } + else if (src2t == GGML_TYPE_IQ4_NL) { + const int mem_size = 32*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src2t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -1867,6 +1901,7 @@ static bool ggml_metal_graph_compute( case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index f0d77d4..c223a98 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2531,6 +2531,12 @@ typedef struct { uint8_t scales[QK_K/16]; } block_iq1_s; +// Non-linear quants +#define QK4_NL 32 +typedef struct { + half d; + uint8_t qs[QK4_NL/2]; +} block_iq4_nl; //====================================== dot products ========================= @@ -4384,7 +4390,6 @@ void kernel_mul_mv_iq1_s_f32_impl( const uint i13 = im/ne12; const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; @@ -4447,6 +4452,103 @@ void kernel_mul_mv_iq1_s_f32_impl( } } +constexpr constant static float kvalues_iq4nl_f[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +void kernel_mul_mv_iq4_nl_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup float * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK4_NL; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + const int ib_row = first_row * nb; + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0 or 1 + + shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK4_NL + it * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ib = ix; ib < nb; ib += 16) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + + device const block_iq4_nl & xb = x[row*nb + ib]; + device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] | (q4[1] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[2] | (q4[3] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; + qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 16 * QK4_NL; + } + + for (int row = 0; row < 2; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + } + } +} + [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( device const void * src0, @@ -4475,6 +4577,34 @@ kernel void kernel_mul_mv_iq1_s_f32( kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); } +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup float * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); +} //============================= templates and their specializations ============================= @@ -4838,6 +4968,21 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & } } +template +void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + template kernel void kernel_get_rows( device const void * src0, @@ -5381,6 +5526,7 @@ template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_r template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; // // matrix-matrix multiplication @@ -5421,6 +5567,7 @@ template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication @@ -5473,6 +5620,7 @@ template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; // // matrix-vector multiplication @@ -6503,3 +6651,68 @@ kernel void kernel_mul_mv_id_iq1_s_f32( tiisg, sgitg); } + +[[host_name("kernel_mul_mv_id_iq4_nl_f32")]] +kernel void kernel_mul_mv_id_iq4_nl_f32( + device const char * ids, + device const char * src1, + device float * dst, + constant uint64_t & nbi1, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint64_t & nb1, + constant uint & r2, + constant uint & r3, + constant int & idx, + device const char * src00, + device const char * src01, + device const char * src02, + device const char * src03, + device const char * src04, + device const char * src05, + device const char * src06, + device const char * src07, + threadgroup float * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + + kernel_mul_mv_iq4_nl_f32_impl( + src0[id], + (device const float *) (src1 + bid*nb11), + dst + bid*ne0, + ne00, + ne01, + ne02, + ne10, + ne12, + ne0, + ne1, + r2, + r3, + shared_values, + tgpig, + tiisg, + sgitg); +} diff --git a/ggml-quants.c b/ggml-quants.c index 3319d2c..6336538 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3754,6 +3754,26 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in } } +static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + +void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int k) { + assert(k % QK4_NL == 0); + const int nb = k / QK4_NL; + + for (int i = 0; i < nb; i++) { + + const uint8_t * qs = x[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d); + for (int j = 0; j < QK4_NL/2; ++j) { + y[j+ 0] = d * kvalues_iq4nl[qs[j] & 0xf]; + y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >> 4]; + } + y += QK4_NL; + qs += QK4_NL/2; + } +} + //===================================== Q8_K ============================================== void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { @@ -9148,7 +9168,6 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * #endif } -// TODO void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -9452,7 +9471,100 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const *s = sumf; #endif +} +void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * restrict x = vx; + const block_q8_0 * restrict y = vy; + + const int nb = n / QK4_NL; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + uint8x16x2_t q4bits; + int8x16x4_t q4b; + int8x16x4_t q8b; + int32x4_t prod_1, prod_2; + + float sumf = 0; + + for (int ib = 0; ib < nb; ib += 2) { + + q4bits.val[0] = vld1q_u8(x[ib+0].qs); + q4bits.val[1] = vld1q_u8(x[ib+1].qs); + q8b.val[0] = vld1q_s8(y[ib+0].qs); + q8b.val[1] = vld1q_s8(y[ib+0].qs + 16); + q8b.val[2] = vld1q_s8(y[ib+1].qs); + q8b.val[3] = vld1q_s8(y[ib+1].qs + 16); + + q4b.val[0] = vqtbl1q_s8(values, vandq_u8(q4bits.val[0], m4b)); + q4b.val[1] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = vqtbl1q_s8(values, vandq_u8(q4bits.val[1], m4b)); + q4b.val[3] = vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); + + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + sumf += (float)x[ib+0].d * (float)y[ib+0].d * vaddvq_s32(prod_1) + (float)x[ib+1].d * (float)y[ib+1].d * vaddvq_s32(prod_2); + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int ib = 0; ib < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs); + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs); + const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); + const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); + accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)), + _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)), + _mm256_cvtepi32_ps(p_2), accum2); + + y += 2; + x += 2; + } + + *s = hsum_float_8(_mm256_add_ps(accum1, accum2)); + +#else + float sumf = 0; + for (int ib = 0; ib < nb; ++ib) { + const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +#endif } // ================================ IQ2 quantization ============================================= @@ -10729,3 +10841,123 @@ size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, in } return nrow * nblock * sizeof(block_iq1_s); } + +// ============================ 4-bit non-linear quants + +static inline int best_index_int8(int n, const int8_t * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + +static void quantize_row_iq4_nl_impl(const int block_size, const float * GGML_RESTRICT x, + ggml_fp16_t * dh, uint8_t * q4, + float * weight, uint8_t * L, + const int8_t * values, + const float * quant_weights) { + + const int ntry = 7; + + float sigma2 = 0; + for (int j = 0; j < QK4_NL; ++j) sigma2 += x[j]*x[j]; + sigma2 *= 2.f/QK4_NL; + + const int nb = QK4_NL/block_size; + + memset(q4, 0, QK4_NL/2); + for (int ib = 0; ib < nb; ++ib) { + dh[ib] = GGML_FP32_TO_FP16(0.f); + const float * xb = x + ib*block_size; + if (quant_weights) { + const float * qw = quant_weights + ib*block_size; + for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + } else { + for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j]; + } + float amax = 0, max = 0; + for (int j = 0; j < block_size; ++j) { + float ax = fabsf(xb[j]); + if (ax > amax) { + amax = ax; max = xb[j]; + } + } + if (!amax) { + continue; + } + float d = -max/values[0]; + float id = 1/d; + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + int l = best_index_int8(16, values, al); + float q = values[l]; + float w = weight[j]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + float best_id = id; + d = sumqx/sumq2; + float best = d*sumqx; + for (int itry = -ntry; itry <= ntry; ++itry) { + id = (itry + values[0])/max; + sumqx = sumq2 = 0; + for (int j = 0; j < block_size; ++j) { + float al = id*xb[j]; + int l = best_index_int8(16, values, al); + float q = values[l]; + float w = weight[j]; + sumqx += w*q*xb[j]; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + d = sumqx/sumq2; best = d * sumqx; + best_id = id; + } + } + dh[ib] = GGML_FP32_TO_FP16(d); + for (int j = 0; j < block_size; ++j) { + L[ib*block_size + j] = best_index_int8(16, values, best_id*xb[j]); + } + } + for (int i = 0; i < QK4_NL/32; ++i) { + for (int j = 0; j < 16; ++j) { + q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4); + } + } +} + +size_t quantize_iq4_nl(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { + (void)hist; + GGML_ASSERT(n_per_row%QK4_NL == 0); + int nblock = n_per_row/QK4_NL; + char * qrow = (char *)dst; + uint8_t L[QK4_NL]; + float weight[32]; + for (int row = 0; row < nrow; ++row) { + block_iq4_nl * iq4 = (block_iq4_nl *)qrow; + for (int ibl = 0; ibl < nblock; ++ibl) { + const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL; + quantize_row_iq4_nl_impl(32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, weight, L, kvalues_iq4nl, qw); + } + src += n_per_row; + qrow += nblock*sizeof(block_iq4_nl); + } + return nrow * nblock * sizeof(block_iq4_nl); +} + +void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int k) { + assert(k % QK4_NL == 0); + block_iq4_nl * restrict y = vy; + quantize_row_iq4_nl_reference(x, y, k); +} + +void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int k) { + assert(k % QK4_NL == 0); + quantize_iq4_nl(x, y, 1, k, NULL, NULL); +} + diff --git a/ggml-quants.h b/ggml-quants.h index ad381cf..113623b 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -198,6 +198,14 @@ typedef struct { } block_iq1_s; static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding"); +// Non-linear quants +#define QK4_NL 32 +typedef struct { + ggml_fp16_t d; + uint8_t qs[QK4_NL/2]; +} block_iq4_nl; +static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding"); + #ifdef __cplusplus extern "C" { #endif @@ -217,6 +225,7 @@ void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGM void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int k); void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int k); void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int k); +void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); @@ -232,6 +241,7 @@ void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); +void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); // Dequantization void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); @@ -251,6 +261,7 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_ void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); +void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); // Dot product void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -268,6 +279,7 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") @@ -276,6 +288,7 @@ size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq3_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_iq1_s (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); +size_t quantize_iq4_nl (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q3_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); size_t quantize_q4_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix); diff --git a/ggml.c b/ggml.c index b90c7df..5b9fa74 100644 --- a/ggml.c +++ b/ggml.c @@ -690,6 +690,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_IQ4_NL] = { + .type_name = "iq4_nl", + .blck_size = QK4_NL, + .type_size = sizeof(block_iq4_nl), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_iq4_nl, + .from_float = quantize_row_iq4_nl, + .from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference, + .vec_dot = ggml_vec_dot_iq4_nl_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q8_K] = { .type_name = "q8_K", .blck_size = QK_K, @@ -2291,6 +2303,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break; + case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -7724,6 +7737,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: { ggml_compute_forward_add_q_f32(params, dst); } break; @@ -8002,6 +8016,7 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: { ggml_compute_forward_add1_q_f32(params, dst); } break; @@ -8125,6 +8140,7 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: default: { GGML_ASSERT(false); @@ -11022,6 +11038,7 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: { ggml_compute_forward_out_prod_q_f32(params, dst); } break; @@ -11209,6 +11226,7 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: default: { GGML_ASSERT(false); @@ -11410,6 +11428,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: { ggml_compute_forward_get_rows_q(params, dst); } break; @@ -12109,6 +12128,7 @@ static void ggml_compute_forward_alibi( case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -12191,6 +12211,7 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_Q8_K: case GGML_TYPE_I8: case GGML_TYPE_I16: @@ -19725,6 +19746,15 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i result = quantize_iq1_s(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); GGML_ASSERT(result == row_size * nrows); } break; + case GGML_TYPE_IQ4_NL: + { + GGML_ASSERT(start % QK4_NL == 0); + GGML_ASSERT(start % n_per_row == 0); + size_t start_row = start / n_per_row; + size_t row_size = ggml_row_size(type, n_per_row); + result = quantize_iq4_nl(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix); + GGML_ASSERT(result == row_size * nrows); + } break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); diff --git a/ggml.h b/ggml.h index 004d09c..bed7a36 100644 --- a/ggml.h +++ b/ggml.h @@ -355,6 +355,7 @@ extern "C" { GGML_TYPE_IQ2_XS = 17, GGML_TYPE_IQ3_XXS = 18, GGML_TYPE_IQ1_S = 19, + GGML_TYPE_IQ4_NL = 20, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -393,6 +394,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors + GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors }; // available tensor operations: