From 39bb8e7d19b429bc36cdda19bc448735c6af29ad Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 4 May 2023 23:31:35 +0300 Subject: [PATCH] ggml : 2x faster scalar implementations --- ggml.c | 132 +++++++++++++++++++++++++++++---------------------------- 1 file changed, 67 insertions(+), 65 deletions(-) diff --git a/ggml.c b/ggml.c index 516674c70..2cbc9b931 100644 --- a/ggml.c +++ b/ggml.c @@ -615,7 +615,8 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) #if __ARM_NEON -static inline const uint8_t * bytes_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) { +// TODO: obosolete - will be removed +static inline const uint8_t * b4_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) { memcpy(qd, qs, qk/2); for (int l = 0; l < qk/16; ++l) { @@ -875,14 +876,14 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r uint64_t qs[QK4_0 / 16] = {0}; for (int l = 0; l < qk/2; ++l) { - const float v0 = x[i*qk + 0 + l]*id; - const float v1 = x[i*qk + qk/2 + l]*id; + const float x0 = x[i*qk + 0 + l]*id; + const float x1 = x[i*qk + qk/2 + l]*id; - const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f)); - const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f)); + const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - qs[l/8] |= vi0 << (8*(l & 7)); - qs[l/8] |= vi1 << (8*(l & 7) + 4); + qs[l/8] |= xi0 << (8*(l & 7)); + qs[l/8] |= xi1 << (8*(l & 7) + 4); } memcpy(y[i].qs, qs, qk/2); @@ -921,14 +922,14 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r uint64_t qs[QK4_1 / 16] = {0}; for (int l = 0; l < qk/2; ++l) { - const float v0 = (x[0 + l] - min)*id; - const float v1 = (x[qk/2 + l] - min)*id; + const float x0 = (x[0 + l] - min)*id; + const float x1 = (x[qk/2 + l] - min)*id; - const uint64_t vi0 = MIN(15, (int8_t)(v0 + 0.5f)); - const uint64_t vi1 = MIN(15, (int8_t)(v1 + 0.5f)); + const uint64_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint64_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); - qs[l/8] |= vi0 << (8*(l & 7)); - qs[l/8] |= vi1 << (8*(l & 7) + 4); + qs[l/8] |= xi0 << (8*(l & 7)); + qs[l/8] |= xi1 << (8*(l & 7) + 4); } memcpy(y[i].qs, qs, qk/2); @@ -968,14 +969,14 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r uint64_t qs[QK4_2 / 16] = {0}; for (int l = 0; l < qk/2; ++l) { - const float v0 = x[i*qk + 0 + l]*id; - const float v1 = x[i*qk + qk/2 + l]*id; + const float x0 = x[i*qk + 0 + l]*id; + const float x1 = x[i*qk + qk/2 + l]*id; - const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f)); - const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f)); + const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - qs[l/8] |= vi0 << (8*(l & 7)); - qs[l/8] |= vi1 << (8*(l & 7) + 4); + qs[l/8] |= xi0 << (8*(l & 7)); + qs[l/8] |= xi1 << (8*(l & 7) + 4); } memcpy(y[i].qs, qs, qk/2); @@ -1015,18 +1016,18 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r uint64_t qs[QK5_0 / 16] = {0}; for (int l = 0; l < qk/2; ++l) { - const float v0 = x[i*qk + 0 + l]*id; - const float v1 = x[i*qk + qk/2 + l]*id; + const float x0 = x[i*qk + 0 + l]*id; + const float x1 = x[i*qk + qk/2 + l]*id; - const uint64_t vi0 = MIN(31, (int8_t)(v0 + 16.5f)); - const uint64_t vi1 = MIN(31, (int8_t)(v1 + 16.5f)); + const uint64_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint64_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); - qs[l/8] |= vi0 << (8*(l & 7)); - qs[l/8] |= vi1 << (8*(l & 7) + 4); + qs[l/8] |= xi0 << (8*(l & 7)); + qs[l/8] |= xi1 << (8*(l & 7) + 4); // get the 5-th bit and store it in qh at the right position - qh |= ((vi0 & 0x10) >> 4) << (l + 0); - qh |= ((vi1 & 0x10) >> 4) << (l + qk/2); + qh |= ((xi0 & 0x10) >> 4) << (l + 0); + qh |= ((xi1 & 0x10) >> 4) << (l + qk/2); } memcpy( y[i].qs, qs, qk/2); @@ -1447,15 +1448,15 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict const int nb = k / qk; - uint64_t qs[QK4_0 / 8]; - for (int i = 0; i < nb; i++) { const float d = x[i].d; - const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs); + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0xf) - 8; + const int x1 = (x[i].qs[j] >> 4) - 8; - for (int l = 0; l < qk; ++l) { - y[i*qk + l] = (qsp[l] - 8)*d; + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; } } } @@ -1468,21 +1469,22 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict const int nb = k / qk; - uint64_t qs[QK4_0 / 8]; - for (int i = 0; i < nb; i++) { const float d = x[i].d; const float m = x[i].m; - const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs); + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0xf); + const int x1 = (x[i].qs[j] >> 4); - for (int l = 0; l < qk; ++l) { - y[i*qk + l] = qsp[l]*d + m; + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; } } } static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict y, int k) { + // BORKEN !!! static const int qk = QK4_2; assert(qk / 16 == 0); @@ -1495,7 +1497,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict for (int i = 0; i < nb; i++) { const float d = GGML_FP16_TO_FP32(x[i].d); - const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs); + const uint8_t * qsp = b4_from_nibbles_64(qk, x[i].qs, qs); for (int l = 0; l < qk; ++l) { y[i*qk + l] = (qsp[l] - 8)*d; @@ -1511,20 +1513,21 @@ static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict const int nb = k / qk; - uint64_t qs[QK5_0 / 8]; - for (int i = 0; i < nb; i++) { const float d = GGML_FP16_TO_FP32(x[i].d); uint32_t qh; memcpy(&qh, x[i].qh, sizeof(qh)); - const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs); + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - for (int l = 0; l < qk; ++l) { - const uint8_t vh = ((qh & (1u << l)) >> l) << 4; + const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; - y[i*qk + l] = ((qsp[l] | vh) - 16)*d; + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; } } } @@ -2388,17 +2391,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * // scalar float sumf = 0.0; - uint64_t qs[QK8_0 / 8]; - for (int i = 0; i < nb; i++) { - // unpack nibbles into bytes - const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs); - const int8_t * py = y[i].qs; + const int8_t * py = y[i].qs; int sumi = 0; - for (int j = 0; j < qk; ++j) { - sumi += (px[j] - 8) * py[j]; + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0xf) - 8; + const int v1 = (x[i].qs[j] >> 4) - 8; + + sumi += (v0 * py[j]) + (v1 * py[j + qk/2]); } sumf += (x[i].d*y[i].d)*sumi; @@ -2513,16 +2515,16 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * // scalar float sumf = 0.0; - uint64_t qs[QK8_1 / 8]; - for (int i = 0; i < nb; i++) { - const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs); - const int8_t * py = y[i].qs; + const int8_t * py = y[i].qs; int sumi = 0; - for (int j = 0; j < qk; ++j) { - sumi += px[j]*py[j]; + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0xf); + const int v1 = (x[i].qs[j] >> 4); + + sumi += (v0 * py[j]) + (v1 * py[j + qk/2]); } sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1); @@ -2847,22 +2849,22 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * // scalar float sumf = 0.0; - uint64_t qs[QK8_0 / 8]; - for (int i = 0; i < nb; i++) { - // unpack nibbles into bytes - const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs); - const int8_t * py = y[i].qs; + const int8_t * py = y[i].qs; uint32_t qh; memcpy(&qh, x[i].qh, sizeof(qh)); int sumi = 0; - for (int j = 0; j < qk; ++j) { - const int xh = ((qh & (1u << j)) >> j) << 4; + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - sumi += ((px[j] | xh) - 16)*py[j]; + const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + + sumi += (x0 * py[j]) + (x1 * py[j + qk/2]); } sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;