From 0f291e1f65c1d68201e71ce99c89562a36686b6d Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Thu, 8 Jun 2023 19:46:22 +0300 Subject: [PATCH] metal : Q6_K implementation (#1752) * Metal implementation for Q4_K Very slow for now: 42 ms / token, Q4_0 runs in 28 ms/token on my 30-core M2 Max GPU. * Optimizing Q4_K on metal The first token always takes longer, I guess because the metal kernel is being jit-compiled. So, using n = 128 to measure time. At this point Q4_K takes 29.5 ms / token compared to 27.2 ms / token for Q4_0. Quite a bit better than the initial attempt, but still not good enough. * Optimizing q4_K metal dot some more For n = 256 it is now 28.1 ms/token compared to 27 ms/token for q4_0. * Fix after merge with master * Metal implementation for Q6_K Similar to the CUDA implementation. No idea if this is the optimum for Metal, but the few alternative variants I tried all had a lower performance. We get 36.5 ms / token on M2 Max with 30 GPU cores. This corresponds to ~200 GB/second throughput. * clang-tidy : add config back * Much better Q6_K implementation for metal 28.3 ms / token for 7B. Subtracting ~9 ms that is spent in other compute graph operations, we are left with ~19 ms for the matrix multiplications. The model is ~5.5 GB, so we are getting 1000 / 19 * 5.5 = 290 GB/s! --------- Co-authored-by: Iwan Kawrakow --- ggml-metal.m | 17 +++++ ggml-metal.metal | 177 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 187 insertions(+), 7 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index f2a637b7a..626ca871c 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -50,10 +50,12 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_k); + GGML_METAL_DECL_KERNEL(get_rows_q6_k); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); @@ -136,10 +138,12 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_k); + GGML_METAL_ADD_KERNEL(get_rows_q6_k); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); @@ -530,6 +534,15 @@ void ggml_metal_graph_compute( nth1 = 16; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; } break; + case GGML_TYPE_Q6_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32]; + } break; default: { fprintf(stderr, "Asserting on type %d\n",(int)src0t); @@ -560,6 +573,9 @@ void ggml_metal_graph_compute( } else if (src0t == GGML_TYPE_Q4_K) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q6_K) { + [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -576,6 +592,7 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index cbcd59ad4..e851cbd4d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -303,18 +303,37 @@ kernel void kernel_mul_mat_q4_0_f32( sum[ith] += acc*d; } - // accumulate the sum from all threads in the threadgroup + // + // Accumulate the sum from all threads in the threadgroup + // This version is slightly faster than the commented out one below, + // which I copy-pasted from ggerganov's q4_0 dot product for metal. + // threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = nth/2; i > 0; i /= 2) { - if (ith < i) { - sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; } - + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } + + //// accumulate the sum from all threads in the threadgroup + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (uint i = nth/2; i > 0; i /= 2) { + // if (ith < i) { + // sum[ith] += sum[ith + i]; + // } + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + //if (ith == 0) { + // dst[r1*ne0 + r0] = sum[0]; + //} } kernel void kernel_mul_mat_f16_f32( @@ -515,6 +534,13 @@ typedef struct { uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_k; +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +} block_q6_k; + static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { uchar4 r; if (j < 4) { @@ -554,6 +580,38 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i } } +static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = x[i].d; + + device const uint8_t * ql = x[i].ql; + device const uint8_t * qh = x[i].qh; + device const int8_t * sc = x[i].scales; + + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l + 0] = d * sc[is + 0] * q1; + y[l + 32] = d * sc[is + 2] * q2; + y[l + 64] = d * sc[is + 4] * q3; + y[l + 96] = d * sc[is + 6] * q4; + } + y += 128; + ql += 64; + qh += 32; + sc += 8; + } + } +} + kernel void kernel_get_rows_q4_k( device const void * src0, device const int * src1, @@ -665,3 +723,108 @@ kernel void kernel_mul_mat_q4_k_f32( // dst[r1*ne0 + r0] = sum[0]; //} } + +kernel void kernel_get_rows_q6_k( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q6_k( + (device const block_q6_k *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +} + +kernel void kernel_mul_mat_q6_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint2 tpig[[thread_position_in_grid]], // we don't use this for now + uint2 tpitg[[thread_position_in_threadgroup]], + uint2 tptg[[threads_per_threadgroup]]) { + + const uint8_t kmask1 = 0x03; + const uint8_t kmask2 = 0x0C; + const uint8_t kmask3 = 0x30; + const uint8_t kmask4 = 0xC0; + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + + const uint nth = tptg.x*tptg.y; + const uint ith = tptg.y*tpitg.x + tpitg.y; + + const int step = QK_K / tptg.y; // we expect this to be 16 + const int iqs = step * tpitg.y; // 0...240 in steps of 16 + const int ip = iqs / 128; // 0 or 1 + const int il = (iqs - 128*ip)/16; // 0...7 + const int n = 4; + const int is = 8*ip + (n*il)/16; + + float sumf = 0; + for (int i = tpitg.x; i < nb; i += tptg.x) { + + device const uint8_t * ql = x[i].ql + 64*ip + n*il; + device const uint8_t * qh = x[i].qh + 32*ip + n*il; + device const int8_t * sc = x[i].scales + is; + + device const float * y = yy + i * QK_K + 128*ip + n*il; + + const float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + } + + sum[ith] = sumf; + + // + // Accumulate the sum from all threads in the threadgroup + // + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + dst[r1*ne0 + r0] = sum[0]; + } + +}