diff --git a/Makefile b/Makefile index a1b99c6f9..e7470d51a 100644 --- a/Makefile +++ b/Makefile @@ -133,7 +133,7 @@ $(info I CC: $(CCV)) $(info I CXX: $(CXXV)) $(info ) -default: main quantize perplexity embedding +default: main quantize quantize-stats perplexity embedding # # Build library diff --git a/ggml.c b/ggml.c index cf6a81f43..54b9f764f 100644 --- a/ggml.c +++ b/ggml.c @@ -575,7 +575,7 @@ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { // blocks of QK elements // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) typedef struct { - float d; // delta + float d; // delta uint8_t qs[QK / 2]; // nibbles / quants } block_q4_0; static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding"); @@ -584,12 +584,19 @@ static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block si // blocks of QK elements // represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors) typedef struct { - float d; - float m; + float d; // delta + float m; // min uint8_t qs[QK / 2]; // nibbles / quants } block_q4_1; static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding"); +typedef struct { + float d; // delta + int8_t qs[QK]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding"); + + // reference implementation for deterministic creation of model files static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { assert(k % QK == 0); @@ -1042,6 +1049,157 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int #endif } +// reference implementation for deterministic creation of model files +static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { + assert(k % QK == 0); + const int nb = k / QK; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int l = 0; l < QK; l++) { + const float v = x[i*QK + l]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + for (int l = 0; l < QK; ++l) { + const float v = x[i*QK + l]*id; + y[i].qs[l] = roundf(v); + } + } +} + +static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { + assert(k % QK == 0); + const int nb = k / QK; + + block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + for (int l = 0; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#else + // scalar + quantize_row_q8_0_reference(x, y, k); +#endif +} + static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { assert(k % QK == 0); const int nb = k / QK; @@ -2344,12 +2502,12 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); sum00 += x0->m*y0->m; - sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h)); - sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h)); + sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h)); + sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h)); sum00 += x1->m*y1->m; - sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h)); - sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h)); + sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h)); + sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h)); #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t @@ -2417,6 +2575,209 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest *s = sumf; } +static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK; + + assert(n % QK == 0); + assert(nb % 2 == 0); + + const block_q4_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + + float sumf = 0.0; + +#if defined(__ARM_NEON) + float sum0 = 0.0f; + float sum1 = 0.0f; + + for (int i = 0; i < nb; i += 2) { + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + // interleave + const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h); + const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h); + const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h); + const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls); + int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls); + + p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs); + p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs); + + sum0 += x0->d*y0->d*vaddvq_s32(p_0); + sum1 += x1->d*y1->d*vaddvq_s32(p_1); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); + + const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h); + const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h); + + const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h); + const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h); + + const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); + const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); + + sum0 += x0->d*y0->d*vaddvq_s16(p_0); + sum1 += x1->d*y1->d*vaddvq_s16(p_1); +#endif + } + + sumf = sum0 + sum1; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + + __m256i bx = bytesFromNibbles(x[i].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + bx = _mm256_sub_epi8( bx, off ); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(bx, bx); + + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(by, bx); + + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + + const __m256i ones = _mm256_set1_epi16(1); + __m256i xy_q = _mm256_madd_epi16(ones, dot); + + /* Convert to vectore of 8 int32_t to 8 floats */ + __m256 q = _mm256_cvtepi32_ps( xy_q ); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); + } + + // Return horizontal sum of the acc vector + __m128 res = _mm256_extractf128_ps( acc, 1 ); + res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); + res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); + res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); + + sumf = _mm_cvtss_f32( res ); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + + __m128i i32[2]; + for (int j = 0; j < 2; ++j) { + // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes + __m128i bx = bytesFromNibbles( x[i].qs + 8*j ); + __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j)); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m128i off = _mm_set1_epi8( 8 ); + bx = _mm_sub_epi8( bx, off ); + + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(bx, bx); + + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(by, bx); + + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + + const __m128i ones = _mm_set1_epi16(1); + i32[j] = _mm_madd_epi16(ones, dot); + } + + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] )); + // Apply the scale, and accumulate + acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + } + + // Return horizontal sum of the acc vector + __m128 res = _mm256_extractf128_ps( acc, 1 ); + res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); + res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); + res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); + + sumf = _mm_cvtss_f32( res ); +#else + // scalar + for (int i = 0; i < nb; i++) { + const float d0 = x[i].d; + const float d1 = y[i].d; + + const uint8_t * restrict p0 = x[i].qs; + const int8_t * restrict p1 = y[i].qs; + + int sumi = 0; + for (int j = 0; j < QK/2; j++) { + const uint8_t v0 = p0[j]; + + const int i0 = (int8_t) (v0 & 0xf) - 8; + const int i1 = (int8_t) (v0 >> 4) - 8; + + const int i2 = p1[2*j + 0]; + const int i3 = p1[2*j + 1]; + + sumi += i0*i2 + i1*i3; + } + sumf += d0*d1*sumi; + } +#endif + + *s = sumf; +} + // compute GGML_VEC_DOT_UNROLL dot products at once // xs - x row stride in bytes inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { @@ -2663,22 +3024,24 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F16] = 1, [GGML_TYPE_Q4_0] = QK, [GGML_TYPE_Q4_1] = QK, + [GGML_TYPE_Q8_0] = QK, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, }; -static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated"); static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = sizeof(float), [GGML_TYPE_F16] = sizeof(ggml_fp16_t), [GGML_TYPE_Q4_0] = sizeof(block_q4_0), [GGML_TYPE_Q4_1] = sizeof(block_q4_1), + [GGML_TYPE_Q8_0] = sizeof(block_q8_0), [GGML_TYPE_I8] = sizeof(int8_t), [GGML_TYPE_I16] = sizeof(int16_t), [GGML_TYPE_I32] = sizeof(int32_t), }; -static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated"); +static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated"); static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { @@ -2686,11 +3049,12 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = { [GGML_TYPE_F16] = "f16", [GGML_TYPE_Q4_0] = "q4_0", [GGML_TYPE_Q4_1] = "q4_1", + [GGML_TYPE_Q8_0] = "q8_0", [GGML_TYPE_I8] = "i8", [GGML_TYPE_I16] = "i16", [GGML_TYPE_I32] = "i32", }; -static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated"); +static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated"); static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "NONE", @@ -3371,6 +3735,10 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -3431,6 +3799,10 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -3485,6 +3857,10 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3529,6 +3905,10 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3571,6 +3951,10 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3615,6 +3999,10 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { { GGML_ASSERT(false); } break; + case GGML_TYPE_Q8_0: + { + GGML_ASSERT(false); + } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -5437,6 +5825,7 @@ static void ggml_compute_forward_dup( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5518,6 +5907,7 @@ static void ggml_compute_forward_add( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5570,6 +5960,7 @@ static void ggml_compute_forward_sub( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5622,6 +6013,7 @@ static void ggml_compute_forward_mul( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5674,6 +6066,7 @@ static void ggml_compute_forward_div( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5722,6 +6115,7 @@ static void ggml_compute_forward_sqr( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5770,6 +6164,7 @@ static void ggml_compute_forward_sqrt( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5828,6 +6223,7 @@ static void ggml_compute_forward_sum( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5905,6 +6301,7 @@ static void ggml_compute_forward_mean( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -5969,6 +6366,7 @@ static void ggml_compute_forward_repeat( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6017,6 +6415,7 @@ static void ggml_compute_forward_abs( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6065,6 +6464,7 @@ static void ggml_compute_forward_sgn( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6113,6 +6513,7 @@ static void ggml_compute_forward_neg( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6161,6 +6562,7 @@ static void ggml_compute_forward_step( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6209,6 +6611,7 @@ static void ggml_compute_forward_relu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6274,6 +6677,7 @@ static void ggml_compute_forward_gelu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6341,6 +6745,7 @@ static void ggml_compute_forward_silu( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6427,6 +6832,7 @@ static void ggml_compute_forward_norm( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6507,6 +6913,7 @@ static void ggml_compute_forward_rms_norm( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -6908,14 +7315,17 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .dequantize_row_q = dequantize_row_q4_0, .quantize_row_q = quantize_row_q4_0, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference, - .vec_dot_q = ggml_vec_dot_q4_0, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_q4_0_q8_0, }, [GGML_TYPE_Q4_1] = { .dequantize_row_q = dequantize_row_q4_1, .quantize_row_q = quantize_row_q4_1, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference, + .quantize_row_q_dot = quantize_row_q4_1, .vec_dot_q = ggml_vec_dot_q4_1, }, + // TODO: GGML_TYPE_Q8_0 }; // For internal test use @@ -6971,8 +7381,8 @@ static void ggml_compute_forward_mul_mat_q_f32( GGML_ASSERT(ne3 == ne13); const enum ggml_type type = src0->type; - quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q; - vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; + quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot; + vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]); @@ -7041,12 +7451,12 @@ static void ggml_compute_forward_mul_mat_q_f32( if (params->type == GGML_TASK_INIT) { char * wdata = params->wdata; - const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type]; + const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = 0; i11 < ne11; ++i11) { - quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); wdata += row_size; } } @@ -7072,7 +7482,7 @@ static void ggml_compute_forward_mul_mat_q_f32( const int ir1 = MIN(ir0 + dr, nr); void * wdata = params->wdata; - const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type]; + const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; for (int ir = ir0; ir < ir1; ++ir) { // src0 indices @@ -7120,6 +7530,7 @@ static void ggml_compute_forward_mul_mat( switch (src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: { ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst); } break; @@ -7218,6 +7629,7 @@ static void ggml_compute_forward_scale( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7383,6 +7795,7 @@ static void ggml_compute_forward_get_rows( switch (src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: { ggml_compute_forward_get_rows_q(params, src0, src1, dst); } break; @@ -7472,6 +7885,7 @@ static void ggml_compute_forward_diag_mask_inf( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7566,6 +7980,7 @@ static void ggml_compute_forward_soft_max( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -7749,6 +8164,7 @@ static void ggml_compute_forward_rope( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8017,6 +8433,7 @@ static void ggml_compute_forward_conv_1d_1s( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8285,6 +8702,7 @@ static void ggml_compute_forward_conv_1d_2s( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8770,6 +9188,7 @@ static void ggml_compute_forward_flash_attn( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -8981,6 +9400,7 @@ static void ggml_compute_forward_flash_ff( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -9030,6 +9450,7 @@ static void ggml_compute_forward_map_unary( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -9085,6 +9506,7 @@ static void ggml_compute_forward_map_binary( } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -9914,7 +10336,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } else #endif { - cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type]; + cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0]; } } else { GGML_ASSERT(false); diff --git a/ggml.h b/ggml.h index 617298a95..241e96a19 100644 --- a/ggml.h +++ b/ggml.h @@ -204,6 +204,7 @@ enum ggml_type { GGML_TYPE_F16 = 1, GGML_TYPE_Q4_0 = 2, GGML_TYPE_Q4_1 = 3, + GGML_TYPE_Q8_0 = 4, GGML_TYPE_I8, GGML_TYPE_I16, GGML_TYPE_I32, @@ -836,6 +837,7 @@ typedef struct { dequantize_row_q_t dequantize_row_q; quantize_row_q_t quantize_row_q; quantize_row_q_t quantize_row_q_reference; + quantize_row_q_t quantize_row_q_dot; vec_dot_q_t vec_dot_q; } quantize_fns_t;