diff --git a/ggml.c b/ggml.c index c5a276898..ac9497e68 100644 --- a/ggml.c +++ b/ggml.c @@ -884,7 +884,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r static const int qk = QK4_0; assert(qk / 16 == 0); - assert(k % qk == 0); + assert( k % qk == 0); const int nb = k / qk; @@ -919,7 +919,7 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r const int qk = QK4_1; assert(qk / 16 == 0); - assert(k % qk == 0); + assert( k % qk == 0); const int nb = k / qk; @@ -952,48 +952,37 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k // reference implementation for deterministic creation of model files static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) { - assert(k % QK4_2 == 0); + static const int qk = QK4_2; - const int nb = k / QK4_2; + assert(qk / 16 == 0); + assert( k % qk == 0); + + const int nb = k / qk; for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max - float max = 0.0f; + float max = 0.0f; - for (int l = 0; l < QK4_2; l++) { - const float v = x[i*QK4_2 + l]; + for (int l = 0; l < qk; l++) { + const float v = x[i*qk + l]; if (amax < fabsf(v)) { amax = fabsf(v); - max = v; + max = v; } } - const float d = max / -8; - + const float d = max / -8; const float id = d ? 1.0f/d : 0.0f; y[i].d = GGML_FP32_TO_FP16(d); - for (int l = 0; l < QK4_2; l += 2) { - const float v0 = x[i*QK4_2 + l + 0]*id; - const float v1 = x[i*QK4_2 + l + 1]*id; + uint64_t qs[QK4_2 / 16] = {0}; - const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f)); - const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f)); - - assert(vi0 < 16); - assert(vi1 < 16); - - y[i].qs[l/2] = vi0 | (vi1 << 4); - } + nibbles_from_floats_64_0(qk, x + i*qk, id, y[i].qs, qs); } } -static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) { - assert(k % QK4_2 == 0); - - block_q4_2 * restrict y = vy; - +static void quantize_row_q4_2(const float * restrict x, void * restrict y, int k) { quantize_row_q4_2_reference(x, y, k); } @@ -1451,7 +1440,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict static const int qk = QK4_0; assert(qk / 16 == 0); - assert(k % qk == 0); + assert( k % qk == 0); const int nb = k / qk; @@ -1472,7 +1461,7 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict static const int qk = QK4_1; assert(qk / 16 == 0); - assert(k % qk == 0); + assert( k % qk == 0); const int nb = k / qk; @@ -1490,31 +1479,23 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict } } -static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) { - assert(k % QK4_2 == 0); - const int nb = k / QK4_2; +static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict y, int k) { + static const int qk = QK4_2; - const block_q4_2 * restrict x = vx; + assert(qk / 16 == 0); + assert( k % qk == 0); + + const int nb = k / qk; + + uint64_t qs[QK4_2 / 8]; for (int i = 0; i < nb; i++) { const float d = GGML_FP16_TO_FP32(x[i].d); - const uint8_t * restrict pp = x[i].qs; + const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs); - for (int l = 0; l < QK4_2; l += 2) { - const uint8_t vi = pp[l/2]; - - const int8_t vi0 = vi & 0x0F; - const int8_t vi1 = vi >> 4; - - const float v0 = (vi0 - 8)*d; - const float v1 = (vi1 - 8)*d; - - y[i*QK4_2 + l + 0] = v0; - y[i*QK4_2 + l + 1] = v1; - - assert(!isnan(y[i*QK4_2 + l + 0])); - assert(!isnan(y[i*QK4_2 + l + 1])); + for (int l = 0; l < qk; ++l) { + y[i*qk + l] = (qsp[l] - 8)*d; } } } @@ -1634,7 +1615,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_1, }, [GGML_TYPE_Q4_2] = { - .dequantize_row_q = dequantize_row_q4_2, + .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_2, .quantize_row_q = quantize_row_q4_2, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference, .quantize_row_q_dot = quantize_row_q8_0, @@ -2559,11 +2540,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * } static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK8_0; + const int qk = QK8_0; + const int nb = n / qk; - assert(n % QK8_0 == 0); + assert(n % qk == 0); assert(nb % 2 == 0); - assert(QK8_0 == 2*QK4_2); + + assert(qk == 2*QK4_2); const block_q4_2 * restrict x = vx; const block_q8_0 * restrict y = vy; @@ -2599,12 +2582,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - // interleave - const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs); - const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs); - const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs); - const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs); - // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); @@ -2613,22 +2590,22 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vaddq_f32( - vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)), - vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d); + vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)), + vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hs, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d); sumv1 = vmlaq_n_f32(sumv1, vaddq_f32( - vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)), - vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d); + vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)), + vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hs, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d); #else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h)); + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));