From 7a32fcb3b29f4db8aed8a85dc58eb958fb118153 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 25 Apr 2023 23:40:51 +0300 Subject: [PATCH] ggml : add Q8_0 quantization format (rename the old one to Q8_1) (ARM NEON) (#1179) * ggml : add Q8_0 quantization format (rename the old one to Q8_1) * tests : fix test-quantize-fns * ggml : finalize Q8_0 implementation * ggml : use q4_0_q8_0 and q4_2_q8_0 * ggml : fix Q8_0 dot product bug (ARM) * ggml : Q8_0 unroll x2 * ggml : fix bug - using wrong block type * ggml : extend quantize_fns_t with "vec_dot_type" * ggml : fix Q8_0 to use 255 values out of 256 * ggml : fix assert using wrong QK4_2 instead of QK4_3 --- examples/quantize/quantize.cpp | 1 + ggml-cuda.cu | 28 +++ ggml-cuda.h | 1 + ggml.c | 407 +++++++++++++++++++++------------ ggml.h | 3 + llama.cpp | 4 + llama.h | 1 + tests/test-quantize-fns.cpp | 14 +- 8 files changed, 312 insertions(+), 147 deletions(-) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 5b4812c62..ad39a805d 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -16,6 +16,7 @@ int main(int argc, char ** argv) { fprintf(stderr, " type = %d - q4_1\n", LLAMA_FTYPE_MOSTLY_Q4_1); fprintf(stderr, " type = %d - q4_2\n", LLAMA_FTYPE_MOSTLY_Q4_2); fprintf(stderr, " type = %d - q4_3\n", LLAMA_FTYPE_MOSTLY_Q4_3); + fprintf(stderr, " type = %d - q8_0\n", LLAMA_FTYPE_MOSTLY_Q8_0); return 1; } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index fa511c1dc..f104ed5ac 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -37,6 +37,13 @@ typedef struct { } block_q4_3; static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding"); +#define QK8_0 32 +typedef struct { + float d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); + static __global__ void dequantize_block_q4_0(const void * vx, float * y) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -131,6 +138,22 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) { } } +static __global__ void dequantize_block_q8_0(const void * vx, float * y) { + const block_q8_0 * x = (const block_q8_0 *) vx; + + const int i = blockIdx.x; + + const float d = x[i].d; + + const int8_t * pp = x[i].qs; + + for (int l = 0; l < QK8_0; l++) { + const int8_t vi = pp[l]; + + y[i*QK8_0 + l] = vi*d; + } +} + void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_0; dequantize_block_q4_0<<>>(vx, y); @@ -151,6 +174,11 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st dequantize_block_q4_3<<>>(vx, y); } +void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { + const int nb = k / QK8_0; + dequantize_block_q8_0<<>>(vx, y); +} + // buffer pool for cuda #define MAX_CUDA_BUFFERS 16 diff --git a/ggml-cuda.h b/ggml-cuda.h index 370bbc75f..4048ea491 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -35,6 +35,7 @@ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t st void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream); void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream); +void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream); #ifdef __cplusplus } diff --git a/ggml.c b/ggml.c index b4dd88db5..064510eda 100644 --- a/ggml.c +++ b/ggml.c @@ -676,12 +676,18 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong #define QK8_0 32 typedef struct { float d; // delta - float s0; // d * sum(qs[i]) low - float s1; // d * sum(qs[i]) high int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +#define QK8_1 32 +typedef struct { + float d; // delta + float s0; // d * sum(qs[i]) low + float s1; // d * sum(qs[i]) high + int8_t qs[QK8_1]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); // reference implementation for deterministic creation of model files static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { @@ -1231,85 +1237,6 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r } } -static inline int nearest_int(float fval) { - assert(fval <= 4194303.f); - float val = fval + 12582912.f; - int i; memcpy(&i, &val, sizeof(int)); - return (i & 0x007fffff) - 0x00400000; -} - -static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates, - const float * restrict candidates, int8_t * restrict L) { - assert (nmin >= INT8_MIN); - assert (nmax <= INT8_MAX); - float amax = 0; - for (int i=0; i sumlxM2*suml2P) { - if (sumlxP2 > best*suml2P) { - best = sumlxP2/suml2P; bestScale = iscale; - } - } else { - if (sumlxM2 > best*suml2M) { - best = sumlxM2/suml2M; bestScale = -iscale; - } - } - } - float sumlx = 0; int suml2 = 0; - for (int i=0; id * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1); - const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2519,6 +2508,12 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); @@ -2533,21 +2528,21 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs); + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls), v0_0hs, v1_0hs); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls), v0_1hs, v1_1hs); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); #else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs)); + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs)); + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); @@ -2559,7 +2554,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #endif } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8; + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2651,14 +2646,14 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #endif } -static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK8_0; +static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_1; - assert(n % QK8_0 == 0); + assert(n % QK8_1 == 0); assert(nb % 2 == 0); const block_q4_1 * restrict x = vx; - const block_q8_0 * restrict y = vy; + const block_q8_1 * restrict y = vy; // TODO: add AVX / WASM SIMD / etc #if defined(__ARM_NEON) @@ -2670,8 +2665,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * for (int i = 0; i < nb; i += 2) { const block_q4_1 * restrict x0 = &x[i + 0]; const block_q4_1 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + const block_q8_1 * restrict y0 = &y[i + 0]; + const block_q8_1 * restrict y1 = &y[i + 1]; summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1); @@ -2769,7 +2764,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const int8_t * restrict p1 = y[i].qs; // TODO: this is very slow .. - for (int j = 0; j < QK8_0/2; j++) { + for (int j = 0; j < QK8_1/2; j++) { const uint8_t v0 = p0[j]; const float f0 = d0*(v0 & 0xf) + m0; @@ -2942,15 +2937,15 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * #endif } -static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK8_0; +static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_1; - assert(n % QK8_0 == 0); + assert(n % QK8_1 == 0); assert(nb % 2 == 0); - assert(QK8_0 == 2*QK4_2); + assert(QK8_1 == 2*QK4_3); const block_q4_3 * restrict x = vx; - const block_q8_0 * restrict y = vy; + const block_q8_1 * restrict y = vy; #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); @@ -2963,7 +2958,7 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0]; const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_1 * restrict y0 = &y[i + 0]; summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0; summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1; @@ -3046,7 +3041,7 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * int sxy_0 = 0; int sxy_1 = 0; - for (int j = 0; j < QK8_0/4; j++) { + for (int j = 0; j < QK8_1/4; j++) { const uint8_t v0 = x0[j]; const uint8_t v1 = x1[j]; @@ -3059,8 +3054,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * const int y0_0 = y0[2*j + 0]; const int y1_0 = y0[2*j + 1]; - const int y0_1 = y0[2*(j + QK8_0/4) + 0]; - const int y1_1 = y0[2*(j + QK8_0/4) + 1]; + const int y0_1 = y0[2*(j + QK8_1/4) + 0]; + const int y1_1 = y0[2*(j + QK8_1/4) + 1]; sxy_0 += x0_0*y0_0 + x1_0*y1_0; sxy_1 += x0_1*y0_1 + x1_1*y1_1; @@ -3072,6 +3067,91 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * #endif } +static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_0; + + assert(n % QK8_0 == 0); + assert(nb % 2 == 0); + assert(QK8_0 == QK8_0); + + const block_q8_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i += 2) { + const block_q8_0 * restrict x0 = &x[i + 0]; + const block_q8_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + + // load y + const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d); + + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d); + +#else + const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); + const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); + const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); + const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); + + const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); + const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); + const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); + const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); + + const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); + const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); + const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); + const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + const int8_t * restrict x0 = x[i].qs; + const int8_t * restrict y0 = y[i].qs; + + int sumi = 0; + + for (int j = 0; j < QK8_0; j++) { + const int v0 = x0[j]; + const int v1 = y0[j]; + + sumi += v0*v1; + } + + sumf += (x[i].d*y[i].d)*sumi; + } + + *s = sumf; +#endif +} // compute GGML_VEC_DOT_UNROLL dot products at once // xs - x row stride in bytes @@ -3269,6 +3349,14 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { #endif } +inline static void ggml_vec_sum_ggf(const int n, ggml_float * s, const float * x) { + ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (ggml_float)x[i]; + } + *s = sum; +} + inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #ifndef GGML_USE_ACCELERATE float max = -INFINITY; @@ -3322,11 +3410,12 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_2] = QK4_2, [GGML_TYPE_Q4_3] = QK4_3, [GGML_TYPE_Q8_0] = QK8_0, + [GGML_TYPE_Q8_1] = QK8_1, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 11, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), @@ -3336,11 +3425,12 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_2] = sizeof(block_q4_2), [GGML_TYPE_Q4_3] = sizeof(block_q4_3), [GGML_TYPE_Q8_0] = sizeof(block_q8_0), + [GGML_TYPE_Q8_1] = sizeof(block_q8_1), [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 11, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -3351,11 +3441,12 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_2] = "q4_2", [GGML_TYPE_Q4_3] = "q4_3", [GGML_TYPE_Q8_0] = "q8_0", + [GGML_TYPE_Q8_1] = "q8_1", [GGML_TYPE_I8] = "i8", [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 11, "GGML_TYPE_NAME is outdated"); static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = false, @@ -3365,11 +3456,12 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_2] = true, [GGML_TYPE_Q4_3] = true, [GGML_TYPE_Q8_0] = true, + [GGML_TYPE_Q8_1] = true, [GGML_TYPE_I8] = false, [GGML_TYPE_I16] = false, [GGML_TYPE_I32] = false, }; -static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated"); +static_assert(GGML_TYPE_COUNT == 11, "GGML_IS_QUANTIZED is outdated"); static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -6581,6 +6673,7 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q8_0: { ggml_compute_forward_add_q_f32(params, src0, src1, dst); } break; @@ -6839,12 +6932,12 @@ static void ggml_compute_forward_sum_f32( const size_t nb03 = src0->nb[3]; ggml_float sum = 0; - float row_sum = 0; + ggml_float row_sum = 0; for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f32(ne00, + ggml_vec_sum_ggf(ne00, &row_sum, (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); sum += row_sum; @@ -8008,6 +8101,7 @@ static void ggml_compute_forward_mul_mat_q_f32( const enum ggml_type type = src0->type; quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot; vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; + enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); @@ -8067,6 +8161,9 @@ static void ggml_compute_forward_mul_mat_q_f32( else if (type == GGML_TYPE_Q4_3) { dequantize_row_q_cuda = dequantize_row_q4_3_cuda; } + else if (type == GGML_TYPE_Q8_0) { + dequantize_row_q_cuda = dequantize_row_q8_0_cuda; + } else { GGML_ASSERT(false); } @@ -8141,7 +8238,7 @@ static void ggml_compute_forward_mul_mat_q_f32( if (params->type == GGML_TASK_INIT) { char * wdata = params->wdata; - const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; + const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type]; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { @@ -8172,7 +8269,7 @@ static void ggml_compute_forward_mul_mat_q_f32( const int ir1 = MIN(ir0 + dr, nr); void * wdata = params->wdata; - const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; + const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type]; for (int ir = ir0; ir < ir1; ++ir) { // src0 indices @@ -8223,6 +8320,7 @@ static void ggml_compute_forward_mul_mat( case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); } break; @@ -8452,6 +8550,7 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -10973,7 +11072,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } else #endif { - cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; + const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type; + cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q]; } } else { GGML_ASSERT(false); @@ -12242,6 +12342,27 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * return (n/QK4_3*sizeof(block_q4_3)); } +size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int j = 0; j < n; j += k) { + block_q8_0 * restrict y = (block_q8_0 *)dst + j/QK8_0; + + quantize_row_q8_0_reference(src + j, y, k); + + for (int i = 0; i < nb; i++) { + for (int l = 0; l < QK8_0; ++l) { + const int8_t vi = y[i].qs[l]; + + hist[vi/16 + 8]++; + } + } + } + + return (n/QK8_0*sizeof(block_q8_0)); +} + size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) { size_t result = 0; switch (type) { @@ -12269,6 +12390,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i block_q4_3 * block = (block_q4_3*)dst + start / QK4_3; result = ggml_quantize_q4_3(src + start, block, n, n, hist); } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(start % QK8_0 == 0); + block_q8_0 * block = (block_q8_0*)dst + start / QK8_0; + result = ggml_quantize_q8_0(src + start, block, n, n, hist); + } break; default: assert(false); } diff --git a/ggml.h b/ggml.h index 275890781..8300a0c62 100644 --- a/ggml.h +++ b/ggml.h @@ -223,6 +223,7 @@ extern "C" { GGML_TYPE_Q4_2 = 4, GGML_TYPE_Q4_3 = 5, GGML_TYPE_Q8_0 = 6, + GGML_TYPE_Q8_1 = 7, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -832,6 +833,7 @@ extern "C" { GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); @@ -876,6 +878,7 @@ extern "C" { quantize_row_q_t quantize_row_q_reference; quantize_row_q_t quantize_row_q_dot; vec_dot_q_t vec_dot_q; + enum ggml_type vec_dot_type; } quantize_fns_t; quantize_fns_t ggml_internal_get_quantize_fn(size_t i); diff --git a/llama.cpp b/llama.cpp index 28d27916a..25203c9e9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -484,6 +484,7 @@ struct llama_file_loader { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q8_0: break; default: { throw format("unrecognized tensor type %u\n", shard.type); @@ -558,6 +559,7 @@ struct llama_file_saver { case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_2: case GGML_TYPE_Q4_3: + case GGML_TYPE_Q8_0: break; default: LLAMA_ASSERT(false); } @@ -848,6 +850,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) { return "mostly Q4_1, some F16"; case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2"; case LLAMA_FTYPE_MOSTLY_Q4_3: return "mostly Q4_3"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; default: return "unknown, may not work"; } } @@ -1585,6 +1588,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = GGML_TYPE_Q4_2; break; case LLAMA_FTYPE_MOSTLY_Q4_3: quantized_type = GGML_TYPE_Q4_3; break; + case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; default: throw format("invalid output file type %d\n", ftype); }; diff --git a/llama.h b/llama.h index e9e3abea5..ab41798d8 100644 --- a/llama.h +++ b/llama.h @@ -74,6 +74,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors }; LLAMA_API struct llama_context_params llama_context_default_params(); diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 7e091e8c4..a31a18827 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -36,7 +36,7 @@ float array_rmse(const float * a1, const float * a2, size_t n) { // Total quantization error on test data float total_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) { - std::vector tmp_q(test_size); + std::vector tmp_q(2*test_size); std::vector tmp_out(test_size); qfns.quantize_row_q(test_data, tmp_q.data(), test_size); @@ -46,7 +46,7 @@ float total_quantization_error(quantize_fns_t & qfns, size_t test_size, const fl // Total quantization error on test data float reference_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) { - std::vector tmp_q(test_size); + std::vector tmp_q(2*test_size); std::vector tmp_out(test_size); std::vector tmp_out_ref(test_size); @@ -69,10 +69,10 @@ float dot_product(const float * a1, const float * a2, size_t test_size) { // Total dot product error float dot_product_error(quantize_fns_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) { - std::vector tmp_q1(test_size); - std::vector tmp_q2(test_size*2); + std::vector tmp_q1(2*test_size); + std::vector tmp_q2(2*test_size); - qfns.quantize_row_q(test_data1, tmp_q1.data(), test_size); + qfns.quantize_row_q (test_data1, tmp_q1.data(), test_size); qfns.quantize_row_q_dot(test_data2, tmp_q2.data(), test_size); float result = INFINITY; @@ -125,7 +125,7 @@ int main(int argc, char * argv[]) { failed = !(total_error < MAX_QUANTIZATION_TOTAL_ERROR); num_failed += failed; if (failed || verbose) { - printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error); + printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error); } const float reference_error = reference_quantization_error(qfns, test_size, test_data.data()); @@ -139,7 +139,7 @@ int main(int argc, char * argv[]) { failed = !(vec_dot_error < MAX_DOT_PRODUCT_ERROR); num_failed += failed; if (failed || verbose) { - printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error); + printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error); } } }