From 72ff5282bf0388c60821f504c4c8cc2b1f491aa6 Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Thu, 8 Jun 2023 22:28:21 +0300 Subject: [PATCH] metal : add Q2_K implementation (#1762) * metal : add Q2_K implementation 27.1 ms / token on M2 Max 30-core GPU, so about the same speed as Q4_0. Memory throughput is ~156 GB/s. The access pattern used in the Q2_K CUDA implementation resulted in significantly lower performance (~31 ms/token). * Fixing merge conflicts --------- Co-authored-by: Iwan Kawrakow --- ggml-metal.m | 17 ++++ ggml-metal.metal | 201 ++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 200 insertions(+), 18 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 626ca871c..ac4f1346c 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -49,11 +49,13 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); + GGML_METAL_DECL_KERNEL(get_rows_q2_k); GGML_METAL_DECL_KERNEL(get_rows_q4_k); GGML_METAL_DECL_KERNEL(get_rows_q6_k); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); GGML_METAL_DECL_KERNEL(rope); @@ -137,11 +139,13 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); + GGML_METAL_ADD_KERNEL(get_rows_q2_k); GGML_METAL_ADD_KERNEL(get_rows_q4_k); GGML_METAL_ADD_KERNEL(get_rows_q6_k); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); GGML_METAL_ADD_KERNEL(rope); @@ -525,6 +529,15 @@ void ggml_metal_graph_compute( nth1 = 4; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; } break; + case GGML_TYPE_Q2_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32]; + } break; case GGML_TYPE_Q4_K: { GGML_ASSERT(ne02 == 1); @@ -570,6 +583,9 @@ void ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q2_K) { + [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q4_K) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -591,6 +607,7 @@ void ggml_metal_graph_compute( switch (src0->type) { case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; default: GGML_ASSERT(false && "not implemented"); diff --git a/ggml-metal.metal b/ggml-metal.metal index e851cbd4d..43814ed09 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -527,6 +527,13 @@ kernel void kernel_cpy_f32_f32( #define QK_K 256 +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins +} block_q2_k; + typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins @@ -555,6 +562,41 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { return r; } +//========================================== dequantization ============================= + +static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = x[i].d; + const float min = x[i].dmin; + + device const uint8_t * q = x[i].qs; + + int is = 0; + float dl, ml; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + uint8_t sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; + + sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; + + shift += 2; + } + q += 32; + } + + } +} + static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -586,12 +628,12 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i for (int i = 0; i < nb; i++) { - const float d = x[i].d; - device const uint8_t * ql = x[i].ql; device const uint8_t * qh = x[i].qh; device const int8_t * sc = x[i].scales; + const float d = x[i].d; + for (int n = 0; n < QK_K; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; @@ -612,6 +654,22 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i } } +kernel void kernel_get_rows_q2_k( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q2_k( + (device const block_q2_k *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +} + kernel void kernel_get_rows_q4_k( device const void * src0, device const int * src1, @@ -628,6 +686,129 @@ kernel void kernel_get_rows_q4_k( (device float *) ((device char *) dst + i*nb1), ne00); } +kernel void kernel_get_rows_q6_k( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q6_k( + (device const block_q6_k *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +} + +//====================================== dot products ========================= + +kernel void kernel_mul_mat_q2_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint2 tpig[[thread_position_in_grid]], // we don't use this for now + uint2 tpitg[[thread_position_in_threadgroup]], + uint2 tptg[[threads_per_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + + const int nth = tptg.x*tptg.y; + const int ith = tptg.y*tpitg.x + tpitg.y; + + + const int tid = tpitg.y; // 0...16 + const int il = tid/4; // 0...3 + const int ir = tid%4; // 0...3 + const int ip = il/2; // 0 or 1 + const int shift1 = 4*(il%2);// 0 or 4 + const int shift2 = shift1+2;// 2 or 6 + const int n = 8; + const int is = 4*il + (n*ir)/16; + + sum[ith] = 0.0f; + + float sumf = 0; + for (int i = tpitg.x; i < nb; i += tptg.x) { + + device const uint8_t * q = x[i].qs + 32*ip + n*ir; + device const uint8_t * scales = x[i].scales + is; + + uint8_t d1 = scales[0] & 0xF; + uint8_t m1 = scales[0] >> 4; + uint8_t d2 = scales[2] & 0xF; + uint8_t m2 = scales[2] >> 4; + + device const float * y = yy + i*QK_K + 64*il + n*ir; + + const float dall = (float)x[i].d; + const float dmin = (float)x[i].dmin; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0]; + s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32]; + } + sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2); + + + } + sum[ith] = sumf; + + // + // Accumulate the sum from all threads in the threadgroup + // This version is slightly faster than the commented out one below, + // which I copy-pasted from ggerganov's q4_0 dot product for metal. + // + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + dst[r1*ne0 + r0] = sum[0]; + } + + //// accumulate the sum from all threads in the threadgroup + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (uint i = nth/2; i > 0; i /= 2) { + // if (ith < i) { + // sum[ith] += sum[ith + i]; + // } + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + //if (ith == 0) { + // dst[r1*ne0 + r0] = sum[0]; + //} +} + kernel void kernel_mul_mat_q4_k_f32( device const void * src0, device const float * src1, @@ -724,22 +905,6 @@ kernel void kernel_mul_mat_q4_k_f32( //} } -kernel void kernel_get_rows_q6_k( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q6_k( - (device const block_q6_k *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - kernel void kernel_mul_mat_q6_k_f32( device const void * src0, device const float * src1,