metal : PP speedup (#3084)

* Minor speed gains for all quantization types

* metal: faster kernel_scale via float4

* Various other speedups for "small" kernels

* metal: faster soft_max vial float4

* metal: faster diagonal infinity

Although, to me it looks like one should simply
fuse scale + diagnonal infinity + soft_max on the
KQtensor.

* Another faster f16 x f32 matrix multiply kernel

* Reverting the diag infinity change

It does work for PP, but somehow it fails for TG.
Need to look more into it.

* metal: add back faster diagonal infinity

This time more carefully

* metal : minor (readibility)

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Kawrakow 2023-09-11 09:30:11 +02:00 committed by GitHub
parent 6eeb4d9083
commit f31b6f4e2d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 220 additions and 104 deletions

View file

@ -63,7 +63,9 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(gelu);
GGML_METAL_DECL_KERNEL(soft_max);
GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@ -77,6 +79,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@ -218,7 +221,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(gelu);
GGML_METAL_ADD_KERNEL(soft_max);
GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@ -232,6 +237,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@ -286,7 +292,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(relu);
GGML_METAL_DEL_KERNEL(gelu);
GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
GGML_METAL_DEL_KERNEL(soft_max_4);
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@ -300,6 +307,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@ -767,7 +775,7 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
const int64_t n = ggml_nelements(dst);
const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@ -779,7 +787,7 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@ -799,7 +807,7 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@ -813,13 +821,16 @@ void ggml_metal_graph_compute(
{
const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_soft_max];
if (ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_soft_max_4];
} else {
[encoder setComputePipelineState:ctx->pipeline_soft_max];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
@ -827,14 +838,23 @@ void ggml_metal_graph_compute(
{
const int n_past = ((int32_t *)(dst->op_params))[0];
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
if (ne00%8 == 0) {
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
} else {
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
if (ne00%8 == 0) {
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
}
else {
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
}
} break;
case GGML_OP_MUL_MAT:
{
@ -881,6 +901,7 @@ void ggml_metal_graph_compute(
} else {
int nth0 = 32;
int nth1 = 1;
int nrows = 1;
// use custom matrix x vector kernel
switch (src0t) {
@ -890,8 +911,12 @@ void ggml_metal_graph_compute(
nth1 = 1;
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
nrows = ne11;
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
nrows = 4;
}
} break;
case GGML_TYPE_Q4_0:
@ -1012,7 +1037,7 @@ void ggml_metal_graph_compute(
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
int64_t ny = (ne11 + 3)/4;
int64_t ny = (ne11 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}

View file

@ -63,18 +63,18 @@ kernel void kernel_mul_row(
}
kernel void kernel_scale(
device const float * src0,
device float * dst,
device const float4 * src0,
device float4 * dst,
constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}
kernel void kernel_silu(
device const float * src0,
device float * dst,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
float x = src0[tpig];
device const float4 & x = src0[tpig];
dst[tpig] = x / (1.0f + exp(-x));
}
@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
kernel void kernel_gelu(
device const float * src0,
device float * dst,
device const float4 * src0,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
float x = src0[tpig];
device const float4 & x = src0[tpig];
// BEWARE !!!
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
@ -107,7 +107,6 @@ kernel void kernel_soft_max(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
@ -119,64 +118,70 @@ kernel void kernel_soft_max(
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
buf[tpitg[0]] = -INFINITY;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
float lmax = psrc0[tpitg[0]];
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
lmax = MAX(lmax, psrc0[i00]);
}
// reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg[0]/2; i > 0; i /= 2) {
if (tpitg[0] < i) {
buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
//// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
// the loop, and when that is done, buf[0] has the correct (synchronized) value
//if (tpitg[0] == 0) {
// buf[0] = buf[0];
//}
//threadgroup_barrier(mem_flags::mem_threadgroup);
const float max = buf[0];
const float max = simd_max(lmax);
// parallel sum
buf[tpitg[0]] = 0.0f;
float lsum = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
const float exp_psrc0 = exp(psrc0[i00] - max);
buf[tpitg[0]] += exp_psrc0;
lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// whish to compute it twice.
pdst[i00] = exp_psrc0;
}
// reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg[0]/2; i > 0; i /= 2) {
if (tpitg[0] < i) {
buf[tpitg[0]] += buf[tpitg[0] + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// broadcast - not needed, see above
//// broadcast
//if (tpitg[0] == 0) {
// buf[0] = buf[0];
//}
//threadgroup_barrier(mem_flags::mem_threadgroup);
const float sum = buf[0];
const float sum = simd_sum(lsum);
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
pdst[i00] /= sum;
}
}
kernel void kernel_soft_max_4(
device const float * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
float4 lmax4 = psrc4[tpitg[0]];
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
lmax4 = fmax(lmax4, psrc4[i00]);
}
float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
const float max = simd_max(lmax);
// parallel sum
float4 lsum4 = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
const float4 exp_psrc4 = exp(psrc4[i00] - max);
lsum4 += exp_psrc4;
pdst4[i00] = exp_psrc4;
}
float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
const float sum = simd_sum(lsum);
for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
pdst4[i00] /= sum;
}
}
kernel void kernel_diag_mask_inf(
device const float * src0,
device float * dst,
@ -192,6 +197,33 @@ kernel void kernel_diag_mask_inf(
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
} else {
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
}
}
kernel void kernel_diag_mask_inf_8(
device const float4 * src0,
device float4 * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int & n_past,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i = 2*tpig[0];
dst[i+0] = src0[i+0];
dst[i+1] = src0[i+1];
int64_t i4 = 4*i;
const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
const int64_t i00 = i4;
for (int k = 3; k >= 0; --k) {
if (i00 + 4 + k <= n_past + i01) {
break;
}
dst[i+1][k] = -INFINITY;
if (i00 + k > n_past + i01) {
dst[i][k] = -INFINITY;
}
}
}
@ -616,6 +648,49 @@ kernel void kernel_mul_mat_f16_f32(
}
}
// Assumes row size (ne00) is a multiple of 4
kernel void kernel_mul_mat_f16_f32_l4(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
const int nrows = ne11;
const int64_t r0 = tgpig.x;
const int64_t im = tgpig.z;
device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
for (int r1 = 0; r1 < nrows; ++r1) {
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
kernel void kernel_alibi_f32(
device const float * src0,
device float * dst,
@ -1800,29 +1875,34 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const half d = il ? (xb->d / 16.h) : xb->d;
const half m = il ? ( -8.h * 16.h) : -8.h;
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = il ? 0xF000 : 0x0F00;
const ushort mask1 = mask0 << 8;
for (int i=0;i<8;i++) {
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
}
}
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const half d = il ? (xb->d / 16.h) : xb->d;
const half m = xb->m;
const float d1 = il ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = il ? 0xF000 : 0x0F00;
const ushort mask1 = mask0 << 8;
for (int i=0;i<8;i++) {
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
}
}
@ -1858,7 +1938,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
const float d_all = (float)(xb->d);
const half d_all = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs;
device const uint8_t * h = (device const uint8_t *)xb->hmask;
device const int8_t * scales = (device const int8_t *)xb->scales;
@ -1871,17 +1951,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
((il/4)>0 ? 12 : 3);
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
(scale_2&kmask2) | ((scale_1&kmask1) << 4);
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
const half ml = 4.h * dl;
il = (il/2)%4;
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
il = (il/2) & 3;
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl *= coef;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
}
#else
float kcoef = il&1 ? 1.f/16.f : 1.f;
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@ -1895,31 +1978,37 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
#endif
}
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
}
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
device const uint8_t * q = xb->qs;
device const uchar * q = xb->qs;
#if QK_K == 256
const float d = (float)(xb->d);
const float min = (float)(xb->dmin);
short is = (il/4) * 2;
q = q + (il/4) * 32 + 16 * (il&1);
il = il%4;
const uchar4 sc = get_scale_min_k4(is, xb->scales);
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
const float ml = il<2 ? min * sc[1] : min * sc[3];
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const half d = il < 2 ? xb->d : xb->d / 16.h;
const half min = xb->dmin;
const half dl = d * sc[0];
const half ml = min * sc[1];
#else
q = q + 16 * (il&1);
device const uint8_t * s = xb->scales;
device const half2 * dh = (device const half2 *)xb->d;
const float2 d = (float2)dh[0];
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
#endif
const ushort mask = il<2 ? 0x0F : 0xF0;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
}
template <typename type4x4>
@ -1928,19 +2017,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
device const uint8_t * qh = xb->qh;
#if QK_K == 256
const float d = (float)(xb->d);
const float min = (float)(xb->dmin);
short is = (il/4) * 2;
q = q + 32 * (il/4) + 16 * (il&1);
qh = qh + 16 * (il&1);
uint8_t ul = 1 << (il/2);
il = il%4;
const uchar4 sc = get_scale_min_k4(is, xb->scales);
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
const float ml = il<2 ? min * sc[1] : min * sc[3];
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const half d = il < 2 ? xb->d : xb->d / 16.h;
const half min = xb->dmin;
const half dl = d * sc[0];
const half ml = min * sc[1];
const ushort mask = il<2 ? 0x0F : 0xF0;
const float qh_val = il<2 ? 16.f : 256.f;
const ushort mask = il<2 ? 0x0F : 0xF0;
const half qh_val = il<2 ? 16.h : 256.h;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}
@ -1959,7 +2048,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
const float d_all = (float)(xb->d);
const half d_all = xb->d;
device const uint8_t * ql = (device const uint8_t *)xb->ql;
device const uint8_t * qh = (device const uint8_t *)xb->qh;
device const int8_t * scales = (device const int8_t *)xb->scales;
@ -1967,19 +2056,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
#if QK_K == 256
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
qh = qh + 32*(il/8) + 16*(il&1);
float sc = scales[(il%2) + 2 * ((il/2))];
il = (il/2)%4;
half sc = scales[(il%2) + 2 * ((il/2))];
il = (il/2) & 3;
#else
ql = ql + 16 * (il&1);
float sc = scales[il];
half sc = scales[il];
#endif
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
const half coef = il>1 ? 1.f/16.h : 1.h;
const half ml = d_all * sc * 32.h;
const half dl = d_all * sc * coef;
for (int i = 0; i < 16; ++i) {
uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
const float coef = il>1 ? 1.f/16.f : 1.f;
float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
reg[i/4][i%4] = d_all * sc * q * coef;
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
reg[i/4][i%4] = dl * q - ml;
}
}