ggml : add ARM_NEON quantize_row_q4_1()

This commit is contained in:
Georgi Gerganov 2023-03-29 22:03:02 +03:00
parent 3b44d30d9b
commit f202ada131
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

56
ggml.c
View file

@ -564,10 +564,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
}
}
#elif __ARM_NEON
uint8_t pp[QK/2];
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
@ -579,7 +576,8 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
amax = MAX(
// absolute max
const float amax = MAX(
MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
@ -593,11 +591,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
const int32x4_t vi = vcvtq_s32_f32(vf);
pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
}
memcpy(y[i].qs, pp, sizeof(pp));
}
#elif defined(__AVX2__)
for (int i = 0; i < nb; i++) {
@ -665,7 +661,6 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
}
#elif defined(__wasm_simd128__)
uint8_t pp[QK/2];
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
@ -694,11 +689,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
}
memcpy(y[i].qs, pp, sizeof(pp));
}
#else
// scalar
@ -750,11 +743,11 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
assert(k % QK == 0);
#if defined(__AVX2__)
const int nb = k / QK;
block_q4_1 * restrict y = vy;
#if defined(__AVX2__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
@ -828,6 +821,41 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
__m128i res = packNibbles( i0 );
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
}
#elif __ARM_NEON
for (int i = 0; i < nb; i++) {
float32x4_t srcv[8];
float32x4_t minv[8];
float32x4_t maxv[8];
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l + 4]);
for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]);
for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]);
for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l + 4]);
const float min = vminvq_f32(minv[0]);
const float max = vmaxvq_f32(maxv[0]);
const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
y[i].d = d;
y[i].m = min;
const float32x4_t minv0 = vdupq_n_f32(min);
for (int l = 0; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
const int32x4_t vi = vcvtq_s32_f32(v);
y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
}
}
#else
// scalar
quantize_row_q4_1_reference(x, vy, k);