ggml : 2x faster scalar implementations

This commit is contained in:
Georgi Gerganov 2023-05-04 23:31:35 +03:00
parent 796f8ae261
commit 39bb8e7d19
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

132
ggml.c
View file

@ -615,7 +615,8 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
#if __ARM_NEON
static inline const uint8_t * bytes_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) {
// TODO: obosolete - will be removed
static inline const uint8_t * b4_from_nibbles_64(const int qk, const uint8_t * qs, uint64_t * qd) {
memcpy(qd, qs, qk/2);
for (int l = 0; l < qk/16; ++l) {
@ -875,14 +876,14 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
uint64_t qs[QK4_0 / 16] = {0};
for (int l = 0; l < qk/2; ++l) {
const float v0 = x[i*qk + 0 + l]*id;
const float v1 = x[i*qk + qk/2 + l]*id;
const float x0 = x[i*qk + 0 + l]*id;
const float x1 = x[i*qk + qk/2 + l]*id;
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
qs[l/8] |= vi0 << (8*(l & 7));
qs[l/8] |= vi1 << (8*(l & 7) + 4);
qs[l/8] |= xi0 << (8*(l & 7));
qs[l/8] |= xi1 << (8*(l & 7) + 4);
}
memcpy(y[i].qs, qs, qk/2);
@ -921,14 +922,14 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
uint64_t qs[QK4_1 / 16] = {0};
for (int l = 0; l < qk/2; ++l) {
const float v0 = (x[0 + l] - min)*id;
const float v1 = (x[qk/2 + l] - min)*id;
const float x0 = (x[0 + l] - min)*id;
const float x1 = (x[qk/2 + l] - min)*id;
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 0.5f));
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 0.5f));
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
qs[l/8] |= vi0 << (8*(l & 7));
qs[l/8] |= vi1 << (8*(l & 7) + 4);
qs[l/8] |= xi0 << (8*(l & 7));
qs[l/8] |= xi1 << (8*(l & 7) + 4);
}
memcpy(y[i].qs, qs, qk/2);
@ -968,14 +969,14 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
uint64_t qs[QK4_2 / 16] = {0};
for (int l = 0; l < qk/2; ++l) {
const float v0 = x[i*qk + 0 + l]*id;
const float v1 = x[i*qk + qk/2 + l]*id;
const float x0 = x[i*qk + 0 + l]*id;
const float x1 = x[i*qk + qk/2 + l]*id;
const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
const uint64_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
const uint64_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
qs[l/8] |= vi0 << (8*(l & 7));
qs[l/8] |= vi1 << (8*(l & 7) + 4);
qs[l/8] |= xi0 << (8*(l & 7));
qs[l/8] |= xi1 << (8*(l & 7) + 4);
}
memcpy(y[i].qs, qs, qk/2);
@ -1015,18 +1016,18 @@ static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * r
uint64_t qs[QK5_0 / 16] = {0};
for (int l = 0; l < qk/2; ++l) {
const float v0 = x[i*qk + 0 + l]*id;
const float v1 = x[i*qk + qk/2 + l]*id;
const float x0 = x[i*qk + 0 + l]*id;
const float x1 = x[i*qk + qk/2 + l]*id;
const uint64_t vi0 = MIN(31, (int8_t)(v0 + 16.5f));
const uint64_t vi1 = MIN(31, (int8_t)(v1 + 16.5f));
const uint64_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
const uint64_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
qs[l/8] |= vi0 << (8*(l & 7));
qs[l/8] |= vi1 << (8*(l & 7) + 4);
qs[l/8] |= xi0 << (8*(l & 7));
qs[l/8] |= xi1 << (8*(l & 7) + 4);
// get the 5-th bit and store it in qh at the right position
qh |= ((vi0 & 0x10) >> 4) << (l + 0);
qh |= ((vi1 & 0x10) >> 4) << (l + qk/2);
qh |= ((xi0 & 0x10) >> 4) << (l + 0);
qh |= ((xi1 & 0x10) >> 4) << (l + qk/2);
}
memcpy( y[i].qs, qs, qk/2);
@ -1447,15 +1448,15 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
const int nb = k / qk;
uint64_t qs[QK4_0 / 8];
for (int i = 0; i < nb; i++) {
const float d = x[i].d;
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
for (int j = 0; j < qk/2; ++j) {
const int x0 = (x[i].qs[j] & 0xf) - 8;
const int x1 = (x[i].qs[j] >> 4) - 8;
for (int l = 0; l < qk; ++l) {
y[i*qk + l] = (qsp[l] - 8)*d;
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
}
}
@ -1468,21 +1469,22 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
const int nb = k / qk;
uint64_t qs[QK4_0 / 8];
for (int i = 0; i < nb; i++) {
const float d = x[i].d;
const float m = x[i].m;
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
for (int j = 0; j < qk/2; ++j) {
const int x0 = (x[i].qs[j] & 0xf);
const int x1 = (x[i].qs[j] >> 4);
for (int l = 0; l < qk; ++l) {
y[i*qk + l] = qsp[l]*d + m;
y[i*qk + j + 0 ] = x0*d + m;
y[i*qk + j + qk/2] = x1*d + m;
}
}
}
static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict y, int k) {
// BORKEN !!!
static const int qk = QK4_2;
assert(qk / 16 == 0);
@ -1495,7 +1497,7 @@ static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict
for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
const uint8_t * qsp = b4_from_nibbles_64(qk, x[i].qs, qs);
for (int l = 0; l < qk; ++l) {
y[i*qk + l] = (qsp[l] - 8)*d;
@ -1511,20 +1513,21 @@ static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict
const int nb = k / qk;
uint64_t qs[QK5_0 / 8];
for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
for (int l = 0; l < qk; ++l) {
const uint8_t vh = ((qh & (1u << l)) >> l) << 4;
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
y[i*qk + l] = ((qsp[l] | vh) - 16)*d;
y[i*qk + j + 0 ] = x0*d;
y[i*qk + j + qk/2] = x1*d;
}
}
}
@ -2388,17 +2391,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
// scalar
float sumf = 0.0;
uint64_t qs[QK8_0 / 8];
for (int i = 0; i < nb; i++) {
// unpack nibbles into bytes
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
const int8_t * py = y[i].qs;
const int8_t * py = y[i].qs;
int sumi = 0;
for (int j = 0; j < qk; ++j) {
sumi += (px[j] - 8) * py[j];
for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0xf) - 8;
const int v1 = (x[i].qs[j] >> 4) - 8;
sumi += (v0 * py[j]) + (v1 * py[j + qk/2]);
}
sumf += (x[i].d*y[i].d)*sumi;
@ -2513,16 +2515,16 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
// scalar
float sumf = 0.0;
uint64_t qs[QK8_1 / 8];
for (int i = 0; i < nb; i++) {
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
const int8_t * py = y[i].qs;
const int8_t * py = y[i].qs;
int sumi = 0;
for (int j = 0; j < qk; ++j) {
sumi += px[j]*py[j];
for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0xf);
const int v1 = (x[i].qs[j] >> 4);
sumi += (v0 * py[j]) + (v1 * py[j + qk/2]);
}
sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1);
@ -2847,22 +2849,22 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
// scalar
float sumf = 0.0;
uint64_t qs[QK8_0 / 8];
for (int i = 0; i < nb; i++) {
// unpack nibbles into bytes
const uint8_t * px = bytes_from_nibbles_64(qk, x[i].qs, qs);
const int8_t * py = y[i].qs;
const int8_t * py = y[i].qs;
uint32_t qh;
memcpy(&qh, x[i].qh, sizeof(qh));
int sumi = 0;
for (int j = 0; j < qk; ++j) {
const int xh = ((qh & (1u << j)) >> j) << 4;
for (int j = 0; j < qk/2; ++j) {
const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
sumi += ((px[j] | xh) - 16)*py[j];
const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
sumi += (x0 * py[j]) + (x1 * py[j + qk/2]);
}
sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;