From 74a6d922f12ccfe16b0c265f43be8978c6f25e98 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Mon, 12 Jun 2023 22:39:21 +0300 Subject: [PATCH] Metal implementation for all k_quants (#1807) * metal : improve q4_K 28.3 -> 26.0 ms/token by avoiding a branch in the calculation of the scales. * metal : small improvement for Q4_K * metal : still optimizing Q4_K This commit pushes it down to 25.3 ms / token. The crazy idea of using 6 bits for the scales is really costly on Metal: if I remove the bit fiddling necessary to make the block scales, time goes almost to the Q4_0 23 ms/token. Before pushing the k-quants upstream I had a Q4_K variant that had used 8-bit scales. It wasn't more accurate, used 0.125 bits more per weight, was running slightly slower on the CPU (due to the larger model size and being memory bound there), and the difference was entirely negligible under CUDA. So, I decided to publish the version with 6-bit scales. Perhaps I should re-consider and change to 8-bit scales? * metal : some more optimizations Q2_K: 25.4 ms/token Q6_K: 27.3 ms/token Q4_0: 22.8 ms/token Q4_1: 23.1 ms/token * metal : Q3_K support Something is not quite right yet. * metal : Q5_K support Initial version achieves 31.2 ms/token, 210 GB/s * metal : still not able to figure out why q3_K does not work * Minor * metal : yet another failed attempt to make q3_K work * metal : optimize Q5_K 31.2 ms -> 27.8 ms. 250 GB/s. * metal : q3_K still not working Adding a heavily commented q3_K metal kernel to explain my obviously faulty logic. Perhaps someone could spot the issue? * metal : q3_K finally working Not optimized at all. What was the issue? The scales are not 4-bytes aligned, and I was accessing them with a uint32_t pointer. When I tried that on CUDA, I got an error (illegal memory access) and added a memcpy to a local array of 3 uint32_t's. But on Metal it told me there is no memcpy, so I tried accessing directly. There is no error, just garbage results. At some point I did try accessing the scales with an uint16_t pointer (the scales are for sure 2-byte aligned), but was still getting garbage. I guess, there must have been another bug. No access to scales is via a uint16_t pointer and, after starting from scratch from the C dequantize function, it finally works. * metal : Q3_K 1st optimization pass * metal : Q3_K second optimization pass - 29.6 ms/token * metal : Q3_K cleanup * metal : fixed accidentally broken Q2_K --------- Co-authored-by: Iwan Kawrakow --- ggml-metal.m | 41 +++- ggml-metal.metal | 547 ++++++++++++++++++++++++++++++++++++----------- llama.cpp | 10 +- 3 files changed, 463 insertions(+), 135 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index b73f51f24..658c392e0 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -52,14 +52,18 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); GGML_METAL_DECL_KERNEL(get_rows_q2_k); + GGML_METAL_DECL_KERNEL(get_rows_q3_k); GGML_METAL_DECL_KERNEL(get_rows_q4_k); + GGML_METAL_DECL_KERNEL(get_rows_q5_k); GGML_METAL_DECL_KERNEL(get_rows_q6_k); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(cpy_f32_f16); @@ -153,14 +157,18 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); GGML_METAL_ADD_KERNEL(get_rows_q2_k); + GGML_METAL_ADD_KERNEL(get_rows_q3_k); GGML_METAL_ADD_KERNEL(get_rows_q4_k); + GGML_METAL_ADD_KERNEL(get_rows_q5_k); GGML_METAL_ADD_KERNEL(get_rows_q6_k); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(cpy_f32_f16); @@ -575,6 +583,15 @@ void ggml_metal_graph_compute( nth1 = 16; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32]; } break; + case GGML_TYPE_Q3_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32]; + } break; case GGML_TYPE_Q4_K: { GGML_ASSERT(ne02 == 1); @@ -584,6 +601,15 @@ void ggml_metal_graph_compute( nth1 = 16; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; } break; + case GGML_TYPE_Q5_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32]; + } break; case GGML_TYPE_Q6_K: { GGML_ASSERT(ne02 == 1); @@ -620,15 +646,14 @@ void ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else if (src0t == GGML_TYPE_Q2_K) { + } + else if (src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_Q3_K || + src0t == GGML_TYPE_Q4_K || + src0t == GGML_TYPE_Q5_K || + src0t == GGML_TYPE_Q6_K) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else if (src0t == GGML_TYPE_Q4_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else if (src0t == GGML_TYPE_Q6_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -646,7 +671,9 @@ void ggml_metal_graph_compute( case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break; case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index ccd36386b..09e12a879 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -304,34 +304,22 @@ kernel void kernel_mul_mat_q4_0_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, constant int64_t & ne0, - constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { const int nb = ne00/QK4_0; - const int8_t m8 = 8; - const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb; device const float * y = (device const float *) src1 + r1*ne10; - const uint nth = tptg.x*tptg.y; - const uint ith = tptg.y*tpitg.x + tpitg.y; + const int nth = tptg.x*tptg.y; + const int ith = tptg.y*tpitg.x + tpitg.y; const int ix = tpitg.y/4; // 0 or 1 const int iy = tpitg.y - 4*ix; // 0...3 @@ -351,47 +339,32 @@ kernel void kernel_mul_mat_q4_0_f32( for (int j = 0; j < 4; ++j) { - acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8); - acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8); + acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4); + acc[1] += yl[j] + yl[j+16]; } - sumf += d * (acc[0] + acc[1]); + sumf += d * (acc[0] - 8.f*acc[1]); } sum[ith] = sumf; // // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. // threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } - - //// accumulate the sum from all threads in the threadgroup - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = nth/2; i > 0; i /= 2) { - // if (ith < i) { - // sum[ith] += sum[ith + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - - //if (ith == 0) { - // dst[r1*ne0 + r0] = sum[0]; - //} } kernel void kernel_mul_mat_q4_1_f32( @@ -399,20 +372,10 @@ kernel void kernel_mul_mat_q4_1_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, constant int64_t & ne0, - constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { const int nb = ne00/QK4_1; @@ -460,11 +423,11 @@ kernel void kernel_mul_mat_q4_1_f32( // threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { @@ -671,6 +634,15 @@ typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins } block_q2_k; +// 84 bytes / block + +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits + half d; // super-block scale +} block_q3_k; +// 110 bytes / block typedef struct { half d; // super-block scale for quantized scales @@ -678,6 +650,16 @@ typedef struct { uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_k; +// 144 bytes / block + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_k; +// 176 bytes / block typedef struct { uint8_t ql[QK_K/2]; // quants, lower 4 bits @@ -685,16 +667,19 @@ typedef struct { int8_t scales[QK_K/16]; // scales, quantized with 8 bits half d; // super-block scale } block_q6_k; +// 210 bytes / block static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { uchar4 r; if (j < 4) { - r[0] = q[j+0] & 63; r[1] = q[j+4] & 63; - r[2] = q[j+1] & 63; r[3] = q[j+5] & 63; + r[0] = q[j+0] & 63; + r[2] = q[j+1] & 63; + r[1] = q[j+4] & 63; + r[3] = q[j+5] & 63; } else { r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4); + r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4); } return r; @@ -735,10 +720,65 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i } } +static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + uint16_t aux[8]; + thread const int8_t * scales = (thread const int8_t*)aux; + + for (int i = 0; i < nb; i++) { + + const float d_all = (float)(x[i].d); + + device const uint8_t * q = x[i].qs; + device const uint8_t * h = x[i].hmask; + uint8_t m = 1; + + device const uint16_t * a = (device const uint16_t *)x[i].scales; + aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4); + aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4); + aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4); + aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4); + aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4); + aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4); + aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4); + aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4); + + int is = 0; + float dl; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4)); + } + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4)); + } + + shift += 2; + m <<= 1; + } + q += 32; + } + + } + +} + static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; + for (int i = 0; i < nb; i++) { const float d = x[i].d; @@ -760,6 +800,33 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i } } +static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = (float)(x[i].d); + const float min = (float)(x[i].dmin); + + device const uint8_t * ql = x[i].qs; + device const uint8_t * qh = x[i].qh; + + int is = 0; + uint8_t u1 = 1, u2 = 2; + for (int j = 0; j < QK_K; j += 64) { + const uchar4 sc = get_scale_min_k4(is, x[i].scales); + const float d1 = d * sc[0]; const float m1 = min * sc[1]; + const float d2 = d * sc[2]; const float m2 = min * sc[3]; + for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } + } + +} + static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -808,6 +875,22 @@ kernel void kernel_get_rows_q2_k( (device float *) ((device char *) dst + i*nb1), ne00); } +kernel void kernel_get_rows_q3_k( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q3_k( + (device const block_q3_k *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +} + kernel void kernel_get_rows_q4_k( device const void * src0, device const int * src1, @@ -824,6 +907,22 @@ kernel void kernel_get_rows_q4_k( (device float *) ((device char *) dst + i*nb1), ne00); } +kernel void kernel_get_rows_q5_k( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q5_k( + (device const block_q5_k *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +} + kernel void kernel_get_rows_q6_k( device const void * src0, device const int * src1, @@ -847,20 +946,10 @@ kernel void kernel_mul_mat_q2_k_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, constant int64_t & ne0, - constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], // we don't use this for now uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { @@ -875,7 +964,6 @@ kernel void kernel_mul_mat_q2_k_f32( const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; - const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid%4; // 0...3 @@ -885,35 +973,54 @@ kernel void kernel_mul_mat_q2_k_f32( const int n = 8; const int is = 4*il + (n*ir)/16; + const int y_offset = 64*il + n*ir; + const int q_offset = 32*ip + n*ir; + sum[ith] = 0.0f; float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { - device const uint8_t * q = x[i].qs + 32*ip + n*ir; + device const uint8_t * q = x[i].qs + q_offset; device const uint8_t * scales = x[i].scales + is; uint8_t d1 = scales[0] & 0xF; - uint8_t m1 = scales[0] >> 4; uint8_t d2 = scales[2] & 0xF; + uint8_t m1 = scales[0] >> 4; uint8_t m2 = scales[2] >> 4; - device const float * y = yy + i*QK_K + 64*il + n*ir; + device const float * y = yy + i*QK_K + y_offset; + + //float4 s = {0.f, 0.f, 0.f, 0.f}; + float2 s = {0.f, 0.f}; + float smin = 0; + for (int l = 0; l < n; ++l) { + s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); + s[1] += y[l+32] * ((q[l] >> shift2) & 3); + smin += y[l+ 0] * m1 + y[l+32] * m2; + } const float dall = (float)x[i].d; const float dmin = (float)x[i].dmin; - float4 s = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0]; - s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32]; - } - sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2); - + sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin; } sum[ith] = sumf; + //int mask1 = (ith%4 == 0); + //int mask2 = (ith%16 == 0); + + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i]; + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i]; + //threadgroup_barrier(mem_flags::mem_threadgroup); + //if (ith == 0) { + // for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + // dst[r1*ne0 + r0] = sum[0]; + //} + // // Accumulate the sum from all threads in the threadgroup // This version is slightly faster than the commented out one below, @@ -932,19 +1039,109 @@ kernel void kernel_mul_mat_q2_k_f32( for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } +} - //// accumulate the sum from all threads in the threadgroup - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = nth/2; i > 0; i /= 2) { - // if (ith < i) { - // sum[ith] += sum[ith + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} +kernel void kernel_mul_mat_q3_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint2 tpitg[[thread_position_in_threadgroup]], + uint2 tptg[[threads_per_threadgroup]]) { + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const uint8_t m3 = 3; + const int8_t m4 = 4; + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + + const int nth = tptg.x*tptg.y; + const int ith = tptg.y*tpitg.x + tpitg.y; + + const int tid = tpitg.y; // expecting 16 + const int ip = tid/8; // 0 or 1 + const int il = tid/2 - 4*ip; // 0...3 + const int ir = tid%2; + const int n = 8; + const int l0 = n*ir; + + const uint8_t m = 1 << (4*ip + il); + + const int shift = 2*il; + + const uint16_t s_shift1 = 4*ip; + const uint16_t s_shift2 = s_shift1 + 2*(il/2); + const int ik = 4 + (il%2); + + const int q_offset = 32*ip + l0; + const int y_offset = 128*ip + 32*il + l0; + + //float sumf = 0; + float sumf1 = 0, sumf2 = 0; + for (int i = tpitg.x; i < nb; i += tptg.x) { + + const float d_all = (float)(x[i].d); + + device const uint8_t * q = x[i].qs + q_offset; + device const uint8_t * h = x[i].hmask + l0; + device const float * y = yy + i * QK_K + y_offset; + + device const uint16_t * a = (device const uint16_t *)x[i].scales; + const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); + + float s = 0; + for (int l = 0; l < n; ++l) { + s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4)); + } + float d = d_all * s; + sumf1 += d * scales[0]; + sumf2 += d; + //sumf += d_all * s * (scales[0] - 32); + + s = 0; + for (int l = 0; l < n; ++l) { + s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4)); + } + d = d_all * s; + sumf1 += d * scales[1]; + sumf2 += d; + //sumf += d_all * s * (scales[1] - 32); + + } + + //sum[ith] = sumf; + sum[ith] = sumf1 - 32.f*sumf2; + + // + // Accumulate the sum from all threads in the threadgroup + // + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + dst[r1*ne0 + r0] = sum[0]; + } - //if (ith == 0) { - // dst[r1*ne0 + r0] = sum[0]; - //} } kernel void kernel_mul_mat_q4_k_f32( @@ -952,23 +1149,17 @@ kernel void kernel_mul_mat_q4_k_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, constant int64_t & ne0, - constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], // we don't use this for now uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; @@ -977,37 +1168,55 @@ kernel void kernel_mul_mat_q4_k_f32( device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; - const uint nth = tptg.x*tptg.y; - const uint ith = tptg.y*tpitg.x + tpitg.y; + const int nth = tptg.x*tptg.y; + const int ith = tptg.y*tpitg.x + tpitg.y; const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 - const int ir = tid%4; // 0...3 - const int n = 8; - const int is = 2*il; + const int ir = tid - 4*il;// 0...3 + const int n = 4; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; sum[ith] = 0.0f; + uchar2 sc1, sc2, sc3, sc4; + float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { - device const uint8_t * q = (x + i)->qs + 32*il + n*ir; - device const float * y = yy + i*QK_K + 64*il + n*ir; - device const uint8_t * scales = (x + i)->scales; + device const uint8_t * q1 = (x + i)->qs + q_offset; + device const uint8_t * q2 = q1 + 64; + device const float * y1 = yy + i*QK_K + y_offset; + device const float * y2 = y1 + 128; const float dall = (float)((x + i)->d); const float dmin = (float)((x + i)->dmin); - const uchar4 sc = get_scale_min_k4(is, scales); + device const uint16_t * a = (device const uint16_t *)(x + i)->scales; + sc1 = as_type((uint16_t)(a[im+0] & kmask1)); + sc2 = as_type((uint16_t)(a[im+2] & kmask1)); + sc3 = as_type((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2))); + sc4 = as_type((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2))); float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; for (int l = 0; l < n; ++l) { - s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0]; - s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32]; + + s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4); + s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4); + smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1]; + } - sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]); + sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; } + sum[ith] = sumf; // @@ -1043,25 +1252,114 @@ kernel void kernel_mul_mat_q4_k_f32( //} } +kernel void kernel_mul_mat_q5_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + threadgroup float * sum [[threadgroup(0)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint2 tpitg[[thread_position_in_threadgroup]], + uint2 tptg[[threads_per_threadgroup]]) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + + const int nth = tptg.x*tptg.y; + const int ith = tptg.y*tpitg.x + tpitg.y; + + const int tid = tpitg.y; // 0...16 + const int il = tid/4; // 0...3 + const int ir = tid - 4*il;// 0...3 + const int n = 4; + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + const uint8_t hm1 = 1u << (2*im); + const uint8_t hm2 = hm1 << 1; + const uint8_t hm3 = hm1 << 4; + const uint8_t hm4 = hm2 << 4; + + uchar2 sc1, sc2, sc3, sc4; + + float sumf = 0; + for (int i = tpitg.x; i < nb; i += tptg.x) { + + device const uint8_t * q1 = (x + i)->qs + q_offset; + device const uint8_t * q2 = q1 + 64; + device const uint8_t * qh = (x + i)->qh + l0; + device const float * y1 = yy + i*QK_K + y_offset; + device const float * y2 = y1 + 128; + + const float dall = (float)((x + i)->d); + const float dmin = (float)((x + i)->dmin); + + device const uint16_t * a = (device const uint16_t *)(x + i)->scales; + sc1 = as_type((uint16_t)(a[im+0] & kmask1)); + sc2 = as_type((uint16_t)(a[im+2] & kmask1)); + sc3 = as_type((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2))); + sc4 = as_type((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2))); + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < n; ++l) { + + s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0)); + s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0)); + s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0)); + s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0)); + smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1]; + + } + sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; + + } + sum[ith] = sumf; + + // + // Accumulate the sum from all threads in the threadgroup + // + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + dst[r1*ne0 + r0] = sum[0]; + } + +} + kernel void kernel_mul_mat_q6_k_f32( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, constant int64_t & ne0, - constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], // we don't use this for now uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { @@ -1078,24 +1376,29 @@ kernel void kernel_mul_mat_q6_k_f32( device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; - const uint nth = tptg.x*tptg.y; - const uint ith = tptg.y*tpitg.x + tpitg.y; + const int nth = tptg.x*tptg.y; + const int ith = tptg.y*tpitg.x + tpitg.y; - const int step = QK_K / tptg.y; // we expect this to be 16 - const int iqs = step * tpitg.y; // 0...240 in steps of 16 + // Note: we absolutely assume that tptg.y = 16 and QK_K = 256! + const int iqs = 16 * tpitg.y; const int ip = iqs / 128; // 0 or 1 const int il = (iqs - 128*ip)/16; // 0...7 const int n = 4; - const int is = 8*ip + (n*il)/16; + const int l0 = n*il; + const int is = 8*ip + l0/16; + + const int y_offset = 128*ip + l0; + const int q_offset_l = 64*ip + l0; + const int q_offset_h = 32*ip + l0; float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { - device const uint8_t * ql = x[i].ql + 64*ip + n*il; - device const uint8_t * qh = x[i].qh + 32*ip + n*il; + device const uint8_t * ql = x[i].ql + q_offset_l; + device const uint8_t * qh = x[i].qh + q_offset_h; device const int8_t * sc = x[i].scales + is; - device const float * y = yy + i * QK_K + 128*ip + n*il; + device const float * y = yy + i * QK_K + y_offset; const float dall = x[i].d; diff --git a/llama.cpp b/llama.cpp index a9a7794ae..f0f9124d8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2377,12 +2377,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s printf("size = %8.3f MB\n", tensor.size/1024.0/1024.0); } else { new_type = quantized_type; - // TODO: temporary disabled until Metal / OpenCL support is available - // ref: https://github.com/ggerganov/llama.cpp/issues/1711 - //if (tensor.name == "output.weight") { - // new_type = GGML_TYPE_Q6_K; - //} - if (tensor.name.find("attention.wv.weight") != std::string::npos) { + if (tensor.name == "output.weight") { + new_type = GGML_TYPE_Q6_K; + } + else if (tensor.name.find("attention.wv.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&