diff --git a/ggml-quants.c b/ggml-quants.c index 5c5f2ce..3d94d16 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -462,6 +462,30 @@ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { return res; } +// NOTE: not tested +inline static int8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { + int8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + #else #define ggml_int16x8x2_t int16x8x2_t @@ -476,6 +500,7 @@ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { #define ggml_vld1q_s8_x2 vld1q_s8_x2 #define ggml_vld1q_s8_x4 vld1q_s8_x4 #define ggml_vqtbl1q_s8 vqtbl1q_s8 +#define ggml_vqtbl1q_u8 vqtbl1q_u8 #endif @@ -9488,8 +9513,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v qs += 16; vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); - vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); vs.val[0] = vceqq_u8(vs.val[0], mask2); vs.val[1] = vceqq_u8(vs.val[1], mask2); @@ -9497,8 +9522,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v q3s.val[1] = vsubq_s8(vreinterpretq_s8_u8(veorq_u8(vs.val[1], vreinterpretq_u8_u32(aux32x4_1))), vreinterpretq_s8_u8(vs.val[1])); vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | (signs[3] << 16))); - vs.val[1] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); vs.val[0] = vceqq_u8(vs.val[0], mask2); vs.val[1] = vceqq_u8(vs.val[1], mask2);