diff --git a/ggml-metal.m b/ggml-metal.m index 5c9ecd76e..167ebd467 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -50,12 +50,14 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); + GGML_METAL_DECL_KERNEL(get_rows_q4_1); GGML_METAL_DECL_KERNEL(get_rows_q2_k); GGML_METAL_DECL_KERNEL(get_rows_q4_k); GGML_METAL_DECL_KERNEL(get_rows_q6_k); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); @@ -141,12 +143,14 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); + GGML_METAL_ADD_KERNEL(get_rows_q4_1); GGML_METAL_ADD_KERNEL(get_rows_q2_k); GGML_METAL_ADD_KERNEL(get_rows_q4_k); GGML_METAL_ADD_KERNEL(get_rows_q6_k); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); @@ -545,6 +549,15 @@ void ggml_metal_graph_compute( nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; } break; + case GGML_TYPE_Q4_1: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32]; + } break; case GGML_TYPE_Q2_K: { GGML_ASSERT(ne02 == 1); @@ -596,7 +609,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - if (src0t == GGML_TYPE_Q4_0) { + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q2_K) { @@ -623,6 +636,7 @@ void ggml_metal_graph_compute( switch (src0->type) { case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; diff --git a/ggml-metal.metal b/ggml-metal.metal index c94ef83f9..ccd36386b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -11,6 +11,13 @@ typedef struct { uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; +#define QK4_1 32 +typedef struct { + half d; // delta + half m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; + static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) { const int qk = QK4_0; @@ -31,6 +38,27 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i } } +static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) { + const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const half d = x[i].d; + const half m = x[i].m; + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F); + const int x1 = (x[i].qs[j] >> 4); + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + kernel void kernel_add( device const float * src0, device const float * src1, @@ -212,6 +240,22 @@ kernel void kernel_get_rows_q4_0( (device float *) ((device char *) dst + i*nb1), ne00); } +kernel void kernel_get_rows_q4_1( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q4_1( + (device const block_q4_1 *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +} + kernel void kernel_rms_norm( device const void * src0, device float * dst, @@ -350,6 +394,85 @@ kernel void kernel_mul_mat_q4_0_f32( //} } +kernel void kernel_mul_mat_q4_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint2 tpig[[thread_position_in_grid]], + uint2 tpitg[[thread_position_in_threadgroup]], + uint2 tptg[[threads_per_threadgroup]]) { + const int nb = ne00/QK4_1; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb; + device const float * y = (device const float *) src1 + r1*ne10; + + const uint nth = tptg.x*tptg.y; + const uint ith = tptg.y*tpitg.x + tpitg.y; + + const int ix = tpitg.y/4; // 0 or 1 + const int iy = tpitg.y - 4*ix; // 0...3 + + const int first = 4 * iy; + + float sumf = 0; + + for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { + + const float d = (float)x[i].d; + const float m = (float)x[i].m; + + device const uint8_t * xl = x[i].qs + first; + device const float * yl = y + i * QK4_1 + first; + + float2 acc = {0.0f, 0.0f}; + + for (int j = 0; j < 4; ++j) { + + acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m); + acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m); + + } + + sumf += acc[0] + acc[1]; + } + + sum[ith] = sumf; + + // + // Accumulate the sum from all threads in the threadgroup + // + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + dst[r1*ne0 + r0] = sum[0]; + } +} + kernel void kernel_mul_mat_f16_f32( device const char * src0, device const char * src1,