2-bit quantizations (llama/4897)

* imatrix: load

* imatrix: WIP

* imatrix: Add Q2_K quantization

* imatrix: also guard against Q2_K_S quantization without importance matrix

* imatrix: guard even more against low-bit quantization misuse

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
pull/1765/head^2
Kawrakow 2024-01-14 09:45:56 +02:00 committed by Georgi Gerganov
parent 654baf693d
commit dabc964d83
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 939 additions and 68 deletions

View File

@ -5,6 +5,8 @@
#include <string.h>
#include <assert.h>
#include <float.h>
#include <stdlib.h> // for qsort
#include <stdio.h> // for GGML_ASSERT
#ifdef __ARM_NEON
@ -1639,6 +1641,241 @@ size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n
return (n/QK_K*sizeof(block_q2_K));
}
static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
float rmin, float rdelta, int nstep, bool use_mad) {
float min = x[0];
float max = x[0];
float sum_w = weights ? weights[0] : x[0]*x[0];
float sum_x = sum_w * x[0];
for (int i = 1; i < n; ++i) {
if (x[i] < min) min = x[i];
if (x[i] > max) max = x[i];
float w = weights ? weights[i] : x[i]*x[i];
sum_w += w;
sum_x += w * x[i];
}
if (min > 0) {
min = 0;
}
if (max <= min) {
for (int i = 0; i < n; ++i) L[i] = 0;
*the_min = -min;
return 0.f;
}
float iscale = nmax/(max - min);
float scale = 1/iscale;
float best_mad = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale*(x[i] - min));
L[i] = MAX(0, MIN(nmax, l));
float diff = scale * L[i] + min - x[i];
diff = use_mad ? fabsf(diff) : diff*diff;
float w = weights ? weights[i] : x[i]*x[i];
best_mad += w * diff;
}
if (nstep < 1) {
*the_min = -min;
return scale;
}
for (int is = 0; is <= nstep; ++is) {
iscale = (rmin + rdelta*is + nmax)/(max - min);
float sum_l = 0, sum_l2 = 0, sum_xl = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale*(x[i] - min));
l = MAX(0, MIN(nmax, l));
Laux[i] = l;
float w = weights ? weights[i] : x[i]*x[i];
sum_l += w*l;
sum_l2 += w*l*l;
sum_xl += w*l*x[i];
}
float D = sum_w * sum_l2 - sum_l * sum_l;
if (D > 0) {
float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
if (this_min > 0) {
this_min = 0;
this_scale = sum_xl / sum_l2;
}
float mad = 0;
for (int i = 0; i < n; ++i) {
float diff = this_scale * Laux[i] + this_min - x[i];
diff = use_mad ? fabsf(diff) : diff*diff;
float w = weights ? weights[i] : x[i]*x[i];
mad += w * diff;
}
if (mad < best_mad) {
for (int i = 0; i < n; ++i) {
L[i] = Laux[i];
}
best_mad = mad;
scale = this_scale;
min = this_min;
}
}
}
*the_min = -min;
return scale;
}
static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) {
float max = 0;
for (int i = 0; i < n; ++i) {
max = MAX(max, x[i]);
}
if (!max) { // all zero
for (int i = 0; i < n; ++i) { L[i] = 0; }
return 0.f;
}
float iscale = nmax / max;
for (int i = 0; i < n; ++i) {
L[i] = nearest_int(iscale * x[i]);
}
float scale = 1/iscale;
float best_mse = 0;
for (int i = 0; i < n; ++i) {
float diff = x[i] - scale*L[i];
float w = quant_weights[i];
best_mse += w*diff*diff;
}
for (int is = -4; is <= 4; ++is) {
if (is == 0) continue;
float iscale_is = (0.1f*is + nmax)/max;
float scale_is = 1/iscale_is;
float mse = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale_is*x[i]);
l = MIN(nmax, l);
float diff = x[i] - scale_is*l;
float w = quant_weights[i];
mse += w*diff*diff;
}
if (mse < best_mse) {
best_mse = mse;
iscale = iscale_is;
}
}
float sumlx = 0;
float suml2 = 0;
for (int i = 0; i < n; ++i) {
int l = nearest_int(iscale * x[i]);
l = MIN(nmax, l);
L[i] = l;
float w = quant_weights[i];
sumlx += w*x[i]*l;
suml2 += w*l*l;
}
for (int itry = 0; itry < 5; ++itry) {
int n_changed = 0;
for (int i = 0; i < n; ++i) {
float w = quant_weights[i];
float slx = sumlx - w*x[i]*L[i];
float sl2 = suml2 - w*L[i]*L[i];
if (slx > 0 && sl2 > 0) {
int new_l = nearest_int(x[i] * sl2 / slx);
new_l = MIN(nmax, new_l);
if (new_l != L[i]) {
slx += w*x[i]*new_l;
sl2 += w*new_l*new_l;
if (slx*slx*suml2 > sumlx*sumlx*sl2) {
L[i] = new_l; sumlx = slx; suml2 = sl2;
++n_changed;
}
}
}
}
if (!n_changed) {
break;
}
}
return sumlx / suml2;
}
static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) {
GGML_ASSERT(quant_weights);
assert(k % QK_K == 0);
const int nb = k / QK_K;
const bool requantize = true;
uint8_t L[QK_K];
uint8_t Laux[16];
float mins[QK_K/16];
float scales[QK_K/16];
float sw[QK_K/16];
float weight[QK_K/16];
uint8_t Ls[QK_K/16], Lm[QK_K/16];
for (int i = 0; i < nb; i++) {
memset(sw, 0, QK_K/16*sizeof(float));
float sumx2 = 0;
for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
float sigma2 = sumx2/QK_K;
for (int j = 0; j < QK_K/16; ++j) {
const float * restrict qw = quant_weights + QK_K * i + 16*j;
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
for (int l = 0; l < 16; ++l) sw[j] += weight[l];
scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
}
float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
y[i].d = GGML_FP32_TO_FP16(dm);
y[i].dmin = GGML_FP32_TO_FP16(mm);
dm = GGML_FP16_TO_FP32(y[i].d);
mm = GGML_FP16_TO_FP32(y[i].dmin);
for (int j = 0; j < QK_K/16; ++j) {
y[i].scales[j] = Ls[j] | (Lm[j] << 4);
}
if (requantize) {
for (int j = 0; j < QK_K/16; ++j) {
const float d = dm * (y[i].scales[j] & 0xF);
if (!d) continue;
const float m = mm * (y[i].scales[j] >> 4);
for (int ii = 0; ii < 16; ++ii) {
int l = nearest_int((x[16*j + ii] + m)/d);
l = MAX(0, MIN(3, l));
L[16*j + ii] = l;
}
}
}
#if QK_K == 256
for (int j = 0; j < QK_K; j += 128) {
for (int l = 0; l < 32; ++l) {
y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
}
}
#else
for (int l = 0; l < 16; ++l) {
y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
}
#endif
x += QK_K;
}
}
size_t quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
(void)hist;
int row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
if (!quant_weights) {
quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
}
else {
char * qrow = (char *)dst;
for (int row = 0; row < nrow; ++row) {
quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += row_size;
}
}
return nrow * row_size;
}
//========================= 3-bit (de)-quantization
void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
@ -2584,14 +2821,6 @@ static const uint8_t ksigns_iq2xs[128] = {
static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
(void)x;
(void)y;
(void)k;
assert(k % QK_K == 0);
//fprintf(stderr, "=========================== %s: not implemented\n", __func__);
}
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@ -2618,33 +2847,8 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y
}
}
void quantize_row_iq2_xxs(const float * restrict x, void * restrict vy, int k) {
assert(k % QK_K == 0);
block_iq2_xxs * restrict y = vy;
quantize_row_iq2_xxs_reference(x, y, k);
}
size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK_K == 0);
(void)hist; // TODO: collect histograms
for (int j = 0; j < n; j += k) {
block_iq2_xxs * restrict y = (block_iq2_xxs *)dst + j/QK_K;
quantize_row_iq2_xxs_reference(src + j, y, k);
}
return (n/QK_K*sizeof(block_iq2_xxs));
}
// ====================== 2.3125 bpw (de)-quantization
void quantize_row_iq2_xs_reference(const float * restrict x, block_iq2_xs * restrict y, int k) {
(void)x;
(void)y;
(void)k;
assert(k % QK_K == 0);
//fprintf(stderr, "=========================== %s: not implemented\n", __func__);
}
void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
@ -2670,23 +2874,6 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
}
}
void quantize_row_iq2_xs(const float * restrict x, void * restrict vy, int k) {
assert(k % QK_K == 0);
block_iq2_xs * restrict y = vy;
quantize_row_iq2_xs_reference(x, y, k);
}
size_t ggml_quantize_iq2_xs(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK_K == 0);
(void)hist; // TODO: collect histograms
for (int j = 0; j < n; j += k) {
block_iq2_xs * restrict y = (block_iq2_xs *)dst + j/QK_K;
quantize_row_iq2_xs_reference(src + j, y, k);
}
return (n/QK_K*sizeof(block_iq2_xs));
}
//===================================== Q8_K ==============================================
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@ -7730,3 +7917,666 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
*s = 0.125f * sumf;
#endif
}
// ================================ IQ2 quantization =============================================
typedef struct {
uint64_t * grid;
int * map;
uint16_t * neighbours;
} iq2_entry_t;
static iq2_entry_t iq2_data[2] = {
{NULL, NULL, NULL},
{NULL, NULL, NULL},
};
static inline int iq2_data_index(int grid_size) {
GGML_ASSERT(grid_size == 256 || grid_size == 512);
return grid_size == 256 ? 0 : 1;
}
static int iq2_compare_func(const void * left, const void * right) {
const int * l = (const int *)left;
const int * r = (const int *)right;
return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
}
static void q2xs_init_impl(int grid_size) {
const int gindex = iq2_data_index(grid_size);
if (iq2_data[gindex].grid) {
return;
}
static const uint16_t kgrid_256[256] = {
0, 2, 5, 8, 10, 17, 20, 32, 34, 40, 42, 65, 68, 80, 88, 97,
100, 128, 130, 138, 162, 257, 260, 272, 277, 320, 388, 408, 512, 514, 546, 642,
1025, 1028, 1040, 1057, 1060, 1088, 1090, 1096, 1120, 1153, 1156, 1168, 1188, 1280, 1282, 1288,
1312, 1350, 1385, 1408, 1425, 1545, 1552, 1600, 1668, 1700, 2048, 2053, 2056, 2068, 2088, 2113,
2116, 2128, 2130, 2184, 2308, 2368, 2562, 2580, 4097, 4100, 4112, 4129, 4160, 4192, 4228, 4240,
4245, 4352, 4360, 4384, 4432, 4442, 4480, 4644, 4677, 5120, 5128, 5152, 5157, 5193, 5248, 5400,
5474, 5632, 5654, 6145, 6148, 6160, 6208, 6273, 6400, 6405, 6560, 6737, 8192, 8194, 8202, 8260,
8289, 8320, 8322, 8489, 8520, 8704, 8706, 9217, 9220, 9232, 9280, 9302, 9472, 9537, 9572, 9872,
10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,
16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,
17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,
20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,
22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,
25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,
33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
};
static const uint16_t kgrid_512[512] = {
0, 2, 5, 8, 10, 17, 20, 22, 25, 32, 34, 37, 40, 65, 68, 70,
73, 80, 82, 85, 88, 97, 100, 128, 130, 133, 136, 145, 148, 153, 160, 257,
260, 262, 265, 272, 274, 277, 280, 282, 289, 292, 320, 322, 325, 328, 337, 340,
352, 360, 385, 388, 400, 512, 514, 517, 520, 529, 532, 544, 577, 580, 592, 597,
640, 650, 1025, 1028, 1030, 1033, 1040, 1042, 1045, 1048, 1057, 1060, 1088, 1090, 1093, 1096,
1105, 1108, 1110, 1120, 1153, 1156, 1168, 1280, 1282, 1285, 1288, 1297, 1300, 1312, 1345, 1348,
1360, 1377, 1408, 1537, 1540, 1552, 1574, 1600, 1602, 1668, 2048, 2050, 2053, 2056, 2058, 2065,
2068, 2080, 2085, 2113, 2116, 2128, 2136, 2176, 2208, 2218, 2305, 2308, 2320, 2368, 2433, 2441,
2560, 2592, 2600, 2710, 2720, 4097, 4100, 4102, 4105, 4112, 4114, 4117, 4120, 4129, 4132, 4160,
4162, 4165, 4168, 4177, 4180, 4192, 4202, 4225, 4228, 4240, 4352, 4354, 4357, 4360, 4369, 4372,
4384, 4417, 4420, 4432, 4480, 4500, 4502, 4609, 4612, 4614, 4624, 4672, 4704, 5120, 5122, 5125,
5128, 5137, 5140, 5152, 5185, 5188, 5193, 5200, 5220, 5248, 5377, 5380, 5392, 5440, 5632, 5652,
5705, 6145, 6148, 6160, 6162, 6208, 6228, 6278, 6400, 6405, 6502, 6737, 6825, 8192, 8194, 8197,
8200, 8202, 8209, 8212, 8224, 8257, 8260, 8272, 8320, 8352, 8449, 8452, 8464, 8512, 8520, 8549,
8704, 8738, 8832, 8872, 9217, 9220, 9232, 9257, 9280, 9472, 9537, 9554, 9625, 9729, 9754, 9894,
10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,
16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,
16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,
16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,
17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,
18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,
20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,
21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,
22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,
24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,
32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,
33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,
33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,
35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,
37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,
40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
};
const int kmap_size = 43692;
const int nwant = 2;
const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
uint64_t * kgrid_q2xs;
int * kmap_q2xs;
uint16_t * kneighbors_q2xs;
printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t));
for (int k = 0; k < grid_size; ++k) {
int8_t * pos = (int8_t *)(the_grid + k);
for (int i = 0; i < 8; ++i) {
int l = (kgrid[k] >> 2*i) & 0x3;
pos[i] = 2*l + 1;
}
}
kgrid_q2xs = the_grid;
iq2_data[gindex].grid = the_grid;
kmap_q2xs = (int *)malloc(kmap_size*sizeof(int));
iq2_data[gindex].map = kmap_q2xs;
for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1;
uint64_t aux64;
uint8_t * aux8 = (uint8_t *)&aux64;
for (int i = 0; i < grid_size; ++i) {
aux64 = kgrid_q2xs[i];
uint16_t index = 0;
for (int k=0; k<8; ++k) {
uint16_t q = (aux8[k] - 1)/2;
index |= (q << 2*k);
}
kmap_q2xs[index] = i;
}
int8_t pos[8];
int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
int num_neighbors = 0, num_not_in_map = 0;
for (int i = 0; i < kmap_size; ++i) {
if (kmap_q2xs[i] >= 0) continue;
++num_not_in_map;
for (int k = 0; k < 8; ++k) {
int l = (i >> 2*k) & 0x3;
pos[k] = 2*l + 1;
}
for (int j = 0; j < grid_size; ++j) {
const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
int d2 = 0;
for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
dist2[2*j+0] = d2;
dist2[2*j+1] = j;
}
qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
int n = 0; int d2 = dist2[0];
int nhave = 1;
for (int j = 0; j < grid_size; ++j) {
if (dist2[2*j] > d2) {
if (nhave == nwant) break;
d2 = dist2[2*j];
++nhave;
}
++n;
}
num_neighbors += n;
}
printf("%s: %d neighbours in total\n", __func__, num_neighbors);
kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
iq2_data[gindex].neighbours = kneighbors_q2xs;
int counter = 0;
for (int i = 0; i < kmap_size; ++i) {
if (kmap_q2xs[i] >= 0) continue;
for (int k = 0; k < 8; ++k) {
int l = (i >> 2*k) & 0x3;
pos[k] = 2*l + 1;
}
for (int j = 0; j < grid_size; ++j) {
const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
int d2 = 0;
for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
dist2[2*j+0] = d2;
dist2[2*j+1] = j;
}
qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
kmap_q2xs[i] = -(counter + 1);
int d2 = dist2[0];
uint16_t * start = &kneighbors_q2xs[counter++];
int n = 0, nhave = 1;
for (int j = 0; j < grid_size; ++j) {
if (dist2[2*j] > d2) {
if (nhave == nwant) break;
d2 = dist2[2*j];
++nhave;
}
kneighbors_q2xs[counter++] = dist2[2*j+1];
++n;
}
*start = n;
}
free(dist2);
}
void ggml_init_iq2_quantization(enum ggml_type type) {
if (type == GGML_TYPE_IQ2_XXS) {
q2xs_init_impl(256);
}
else if (type == GGML_TYPE_IQ2_XS) {
q2xs_init_impl(512);
}
else {
fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
}
}
static void q2xs_deinit_impl(int grid_size) {
GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
const int gindex = iq2_data_index(grid_size);
if (iq2_data[gindex].grid) {
free(iq2_data[gindex].grid); iq2_data[gindex].grid = NULL;
free(iq2_data[gindex].map); iq2_data[gindex].map = NULL;
free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL;
}
}
void ggml_deinit_iq2_quantization(enum ggml_type type) {
if (type == GGML_TYPE_IQ2_XXS) {
q2xs_deinit_impl(256);
}
else if (type == GGML_TYPE_IQ2_XS) {
q2xs_deinit_impl(512);
}
else {
fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
}
}
static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
int num_neighbors = neighbours[0];
GGML_ASSERT(num_neighbors > 0);
float best_d2 = FLT_MAX;
int grid_index = -1;
for (int j = 1; j <= num_neighbors; ++j) {
const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
float d2 = 0;
for (int i = 0; i < 8; ++i) {
float q = pg[i];
float diff = scale*q - xval[i];
d2 += weight[i]*diff*diff;
}
if (d2 < best_d2) {
best_d2 = d2; grid_index = neighbours[j];
}
}
GGML_ASSERT(grid_index >= 0);
const int8_t * pg = (const int8_t *)(grid + grid_index);
for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
return grid_index;
}
static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
const int gindex = iq2_data_index(256);
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
const int * kmap_q2xs = iq2_data[gindex].map;
const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
GGML_ASSERT(quant_weights);
GGML_ASSERT(kgrid_q2xs);
GGML_ASSERT(kmap_q2xs);
GGML_ASSERT(kneighbors_q2xs);
GGML_ASSERT(n%QK_K == 0);
const int kMaxQ = 3;
const int nbl = n/256;
block_iq2_xxs * y = vy;
float scales[QK_K/32];
float weight[32];
float xval[32];
int8_t L[32];
int8_t Laux[32];
float waux[32];
bool is_on_grid[4];
bool is_on_grid_aux[4];
uint8_t block_signs[4];
uint32_t q2[2*(QK_K/32)];
for (int ibl = 0; ibl < nbl; ++ibl) {
y[ibl].d = GGML_FP32_TO_FP16(0.f);
memset(q2, 0, QK_K/4);
float max_scale = 0;
const float * xbl = x + QK_K*ibl;
float sumx2 = 0;
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
float sigma2 = sumx2/QK_K;
for (int ib = 0; ib < QK_K/32; ++ib) {
const float * xb = xbl + 32*ib;
const float * qw = quant_weights + QK_K*ibl + 32*ib;
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
for (int k = 0; k < 4; ++k) {
int nflip = 0;
uint8_t s = 0;
for (int i = 0; i < 8; ++i) {
if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
else {
xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
}
}
if (nflip%2) {
int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
for (int i = 1; i < 8; ++i) {
float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
if (ax < min) {
min = ax; imin = i;
}
}
xval[8*k+imin] = -xval[8*k+imin];
s ^= (1 << imin);
}
block_signs[k] = s & 127;
}
float max = xval[0];
for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
if (!max) {
scales[ib] = 0;
memset(L, 0, 32);
continue;
}
float best = 0;
float scale = max/(2*kMaxQ-1);
for (int is = -9; is <= 9; ++is) {
float id = (2*kMaxQ-1+is*0.1f)/max;
float this_scale = 1/id;
for (int k = 0; k < 4; ++k) {
for (int i = 0; i < 8; ++i) {
int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
}
uint16_t u = 0;
for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
int grid_index = kmap_q2xs[u];
is_on_grid_aux[k] = true;
if (grid_index < 0) {
is_on_grid_aux[k] = false;
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
}
}
float sumqx = 0, sumq2 = 0;
for (int i = 0; i < 32; ++i) {
float w = weight[i];
float q = 2*Laux[i] + 1;
sumqx += w*xval[i]*q;
sumq2 += w*q*q;
}
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
scale = sumqx/sumq2; best = scale*sumqx;
for (int i = 0; i < 32; ++i) L[i] = Laux[i];
for (int k = 0; k < 4; ++k) is_on_grid[k] = is_on_grid_aux[k];
}
}
int n_not_ongrid = 0;
for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
if (n_not_ongrid > 0 && scale > 0) {
float id = 1/scale;
for (int k = 0; k < 4; ++k) {
if (is_on_grid[k]) continue;
uint16_t u = 0;
for (int i = 0; i < 8; ++i) {
int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
l = MAX(0, MIN(kMaxQ-1, l));
u |= (l << 2*i);
}
int grid_index = kmap_q2xs[u];
if (grid_index < 0) {
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
}
const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index);
for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2;
}
float sumqx = 0, sumq2 = 0;
for (int i = 0; i < 32; ++i) {
float w = weight[i];
float q = 2*L[i] + 1;
sumqx += w*xval[i]*q;
sumq2 += w*q*q;
}
if (sumq2 > 0) scale = sumqx/sumq2;
}
if (scale < 0) {
// This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
// and correspondingly flip quant signs.
scale = -scale;
for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
}
for (int k = 0; k < 4; ++k) {
uint16_t u = 0;
for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
int grid_index = kmap_q2xs[u];
if (grid_index < 0) {
printf("Oops: found point %u not on grid:", u);
for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
printf("\n");
GGML_ASSERT(false);
}
q2[2*ib+0] |= (grid_index << 8*k);
q2[2*ib+1] |= (block_signs[k] << 7*k);
}
GGML_ASSERT(scale >= 0);
scales[ib] = scale;
max_scale = MAX(max_scale, scale);
}
if (!max_scale) {
memset(y[ibl].qs, 0, QK_K/4);
continue;
}
float d = max_scale/31;
y[ibl].d = GGML_FP32_TO_FP16(d);
float id = 1/d;
float sumqx = 0, sumq2 = 0;
for (int ib = 0; ib < QK_K/32; ++ib) {
int l = nearest_int(0.5f*(id*scales[ib]-1));
l = MAX(0, MIN(15, l));
q2[2*ib+1] |= ((uint32_t)l << 28);
const float * xb = xbl + 32*ib;
const float * qw = quant_weights + QK_K*ibl + 32*ib;
for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
const float db = d * (1 + 2*l);
uint32_t u = 0;
for (int k = 0; k < 4; ++k) {
const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
const float * xk = xb + 8*k;
const float * wk = weight + 8*k;
const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
float best_mse = 0; int best_index = aux8[k];
for (int j = 0; j < 8; ++j) {
float diff = db * grid[j] * signs[j] - xk[j];
best_mse += wk[j] * diff * diff;
}
for (int idx = 0; idx < 256; ++idx) {
grid = (const uint8_t *)(kgrid_q2xs + idx);
float mse = 0;
for (int j = 0; j < 8; ++j) {
float diff = db * grid[j] * signs[j] - xk[j];
mse += wk[j] * diff * diff;
}
if (mse < best_mse) {
best_mse = mse; best_index = idx;
}
}
u |= (best_index << 8*k);
grid = (const uint8_t *)(kgrid_q2xs + best_index);
//grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
for (int j = 0; j < 8; ++j) {
float q = db * grid[j] * signs[j];
sumqx += wk[j] * q * xk[j];
sumq2 += wk[j] * q * q;
}
}
q2[2*ib] = u;
if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
}
memcpy(y[ibl].qs, q2, QK_K/4);
}
}
static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
const int gindex = iq2_data_index(512);
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
const int * kmap_q2xs = iq2_data[gindex].map;
const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
GGML_ASSERT(quant_weights);
GGML_ASSERT(kmap_q2xs);
GGML_ASSERT(kgrid_q2xs);
GGML_ASSERT(kneighbors_q2xs);
GGML_ASSERT(n%QK_K == 0);
const int kMaxQ = 3;
const int nbl = n/256;
block_iq2_xs * y = vy;
float scales[QK_K/16];
float weight[16];
float xval[16];
int8_t L[16];
int8_t Laux[16];
float waux[16];
bool is_on_grid[2];
bool is_on_grid_aux[2];
uint8_t block_signs[2];
uint16_t q2[2*(QK_K/16)];
for (int ibl = 0; ibl < nbl; ++ibl) {
y[ibl].d = GGML_FP32_TO_FP16(0.f);
memset(q2, 0, QK_K/4);
memset(y[ibl].scales, 0, QK_K/32);
float max_scale = 0;
const float * xbl = x + QK_K*ibl;
float sumx2 = 0;
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
float sigma2 = sumx2/QK_K;
for (int ib = 0; ib < QK_K/16; ++ib) {
const float * xb = xbl + 16*ib;
const float * qw = quant_weights + QK_K*ibl + 16*ib;
for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
for (int k = 0; k < 2; ++k) {
int nflip = 0;
uint8_t s = 0;
for (int i = 0; i < 8; ++i) {
if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
else {
xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
}
}
if (nflip%2) {
int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
for (int i = 1; i < 8; ++i) {
float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
if (ax < min) {
min = ax; imin = i;
}
}
xval[8*k+imin] = -xval[8*k+imin];
s ^= (1 << imin);
}
block_signs[k] = s & 127;
}
float max = xval[0];
for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
if (!max) {
scales[ib] = 0;
memset(L, 0, 16);
continue;
}
float best = 0;
float scale = max/(2*kMaxQ-1);
is_on_grid[0] = is_on_grid[1] = true;
for (int is = -9; is <= 9; ++is) {
float id = (2*kMaxQ-1+is*0.1f)/max;
float this_scale = 1/id;
for (int k = 0; k < 2; ++k) {
for (int i = 0; i < 8; ++i) {
int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
}
uint16_t u = 0;
for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
int grid_index = kmap_q2xs[u];
is_on_grid_aux[k] = true;
if (grid_index < 0) {
is_on_grid_aux[k] = false;
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
}
}
float sumqx = 0, sumq2 = 0;
for (int i = 0; i < 16; ++i) {
float w = weight[i];
float q = 2*Laux[i] + 1;
sumqx += w*xval[i]*q;
sumq2 += w*q*q;
}
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
scale = sumqx/sumq2; best = scale*sumqx;
for (int i = 0; i < 16; ++i) L[i] = Laux[i];
for (int k = 0; k < 2; ++k) is_on_grid[k] = is_on_grid_aux[k];
}
}
int n_not_ongrid = 0;
for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
if (n_not_ongrid > 0 && scale > 0) {
float id = 1/scale;
for (int k = 0; k < 2; ++k) {
if (is_on_grid[k]) continue;
uint16_t u = 0;
for (int i = 0; i < 8; ++i) {
int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
l = MAX(0, MIN(kMaxQ-1, l));
u |= (l << 2*i);
L[8*k + i] = l;
}
int grid_index = kmap_q2xs[u];
if (grid_index < 0) {
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
}
}
float sumqx = 0, sumq2 = 0;
for (int i = 0; i < 16; ++i) {
float w = weight[i];
float q = 2*L[i] + 1;
sumqx += w*xval[i]*q;
sumq2 += w*q*q;
}
if (sumq2 > 0) scale = sumqx/sumq2;
}
if (scale < 0) {
scale = -scale;
for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127;
}
for (int k = 0; k < 2; ++k) {
uint16_t u = 0;
for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
int grid_index = kmap_q2xs[u];
if (grid_index < 0) {
printf("Oops: found point %u not on grid:", u);
for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
printf("\n");
GGML_ASSERT(false);
}
q2[2*ib+k] = grid_index | (block_signs[k] << 9);
}
GGML_ASSERT(scale >= 0);
scales[ib] = scale;
max_scale = MAX(max_scale, scale);
}
if (!max_scale) {
memset(y[ibl].qs, 0, QK_K/4);
continue;
}
float d = max_scale/31;
y[ibl].d = GGML_FP32_TO_FP16(d);
float id = 1/d;
for (int ib = 0; ib < QK_K/16; ++ib) {
int l = nearest_int(0.5f*(id*scales[ib]-1));
l = MAX(0, MIN(15, l));
if (ib%2 == 0) y[ibl].scales[ib/2] = l;
else y[ibl].scales[ib/2] |= (l << 4);
}
memcpy(y[ibl].qs, q2, QK_K/4);
}
}
size_t quantize_iq2_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
(void)hist;
GGML_ASSERT(n_per_row%QK_K == 0);
int nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
for (int row = 0; row < nrow; ++row) {
quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += nblock*sizeof(block_iq2_xxs);
}
return nrow * nblock * sizeof(block_iq2_xxs);
}
size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
(void)hist;
GGML_ASSERT(n_per_row%QK_K == 0);
int nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
for (int row = 0; row < nrow; ++row) {
quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += nblock*sizeof(block_iq2_xs);
}
return nrow * nblock * sizeof(block_iq2_xs);
}

View File

@ -196,8 +196,6 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
void quantize_row_iq2_xs_reference (const float * restrict x, block_iq2_xs * restrict y, int k);
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
@ -212,8 +210,6 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
void quantize_row_iq2_xs (const float * restrict x, void * restrict y, int k);
// Dequantization
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
@ -246,3 +242,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx,
void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
//
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
//
size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
size_t quantize_q2_K (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);

36
ggml.c
View File

@ -585,8 +585,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_iq2_xxs),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
.from_float = quantize_row_iq2_xxs,
.from_float_reference = (ggml_from_float_t) quantize_row_iq2_xxs_reference,
.from_float = NULL,
.from_float_reference = NULL,
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
@ -596,8 +596,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.type_size = sizeof(block_iq2_xs),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq2_xs,
.from_float = quantize_row_iq2_xs,
.from_float_reference = (ggml_from_float_t) quantize_row_iq2_xs_reference,
.from_float = NULL,
.from_float_reference = NULL,
.vec_dot = ggml_vec_dot_iq2_xs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
@ -18665,8 +18665,11 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t *
return (n/QK8_0*sizeof(block_q8_0));
}
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start,
int nrows, int n_per_row, int64_t * hist, const float * imatrix) {
(void)imatrix;
size_t result = 0;
int n = nrows * n_per_row;
switch (type) {
case GGML_TYPE_Q4_0:
{
@ -18701,8 +18704,11 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
case GGML_TYPE_Q2_K:
{
GGML_ASSERT(start % QK_K == 0);
block_q2_K * block = (block_q2_K*)dst + start / QK_K;
result = ggml_quantize_q2_K(src + start, block, n, n, hist);
GGML_ASSERT(start % n_per_row == 0);
size_t start_row = start / n_per_row;
size_t row_size = ggml_row_size(type, n_per_row);
result = quantize_q2_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
GGML_ASSERT(result == row_size * nrows);
} break;
case GGML_TYPE_Q3_K:
{
@ -18731,14 +18737,22 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
case GGML_TYPE_IQ2_XXS:
{
GGML_ASSERT(start % QK_K == 0);
block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
GGML_ASSERT(start % n_per_row == 0);
GGML_ASSERT(imatrix);
size_t start_row = start / n_per_row;
size_t row_size = ggml_row_size(type, n_per_row);
result = quantize_iq2_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
GGML_ASSERT(result == row_size * nrows);
} break;
case GGML_TYPE_IQ2_XS:
{
GGML_ASSERT(start % QK_K == 0);
block_iq2_xs * block = (block_iq2_xs*)dst + start / QK_K;
result = ggml_quantize_iq2_xs(src + start, block, n, n, hist);
GGML_ASSERT(start % n_per_row == 0);
GGML_ASSERT(imatrix);
size_t start_row = start / n_per_row;
size_t row_size = ggml_row_size(type, n_per_row);
result = quantize_iq2_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
GGML_ASSERT(result == row_size * nrows);
} break;
case GGML_TYPE_F16:
{

9
ggml.h
View File

@ -2067,10 +2067,13 @@ extern "C" {
GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_iq2_xs (const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst,
int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
// These are needed for IQ2_XS and IQ2_XXS quantizations
GGML_API void ggml_init_iq2_quantization(enum ggml_type type);
GGML_API void ggml_deinit_iq2_quantization(enum ggml_type type);
//
// Importance matrix