diff --git a/ggml-metal.m b/ggml-metal.m
index 969cf7daa..06eb3872e 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -63,6 +63,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
     GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -73,6 +74,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
@@ -81,6 +83,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
@@ -188,6 +191,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         GGML_METAL_ADD_KERNEL(get_rows_q3_K);
         GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -198,6 +202,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
@@ -205,6 +210,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
@@ -747,9 +753,10 @@ void ggml_metal_graph_compute(
                             ne00%32 == 0 &&
                             ne11 > 1) {
                            switch (src0->type) {
-                               case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+                               case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
                                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                               case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
@@ -800,6 +807,15 @@ void ggml_metal_graph_compute(
                                nth1 = 8;
                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
                            } break;
+                       case GGML_TYPE_Q8_0:
+                           {
+                               GGML_ASSERT(ne02 == 1);
+                               GGML_ASSERT(ne12 == 1);
+
+                               nth0 = 8;
+                               nth1 = 8;
+                               [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+                           } break;
                        case GGML_TYPE_Q2_K:
                            {
                                GGML_ASSERT(ne02 == 1);
@@ -871,7 +887,7 @@ void ggml_metal_graph_compute(
                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
                [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

-               if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+               if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
                    src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                }
@@ -896,9 +912,10 @@ void ggml_metal_graph_compute(
        case GGML_OP_GET_ROWS:
            {
                switch (src0->type) {
-                   case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
+                   case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                   case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 7bc3fdf37..82e1a0c7a 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -18,6 +18,12 @@ typedef struct {
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;

+#define QK8_0 32
+typedef struct {
+    half   d;         // delta
+    int8_t qs[QK8_0]; // quants
+} block_q8_0;
+
 kernel void kernel_add(
         device const float * src0,
         device const float * src1,
@@ -357,7 +363,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
    const int first_row = (r0 * nsg + sgitg) * nr;
    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-   device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+   device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
    float yl[16];       // src1 vector cache
    float sumf[nr]={0.f};

@@ -429,6 +435,68 @@ kernel void kernel_mul_mat_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }

+kernel void kernel_mul_mat_q8_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/2;
+    const int il = tiisg%2;
+
+    device const float * yb = y + ix * QK8_0 + 16*il;
+
+    // each thread in a SIMD group deals with half a block.
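+    // ix = tiisg/2 picks this thread's block within the row and il = tiisg%2
+    // picks which 16-quant half of it, so the 32-thread SIMD group covers
+    // nw/2 = 16 blocks per iteration (yb advances by QK8_0*16 floats to match).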
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        for (int i = 0; i < 16; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < 16; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
+        }
+
+        yb += QK8_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32(
         device const char * src0,
         device const char * src1,
@@ -480,7 +548,6 @@ kernel void kernel_mul_mat_f16_f32(
         }
     }
 }
-
 kernel void kernel_alibi_f32(
         device const float * src0,
         device       float * dst,
@@ -1621,12 +1688,12 @@ template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
     const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? (-8.h * 16.h) : -8.h;
+    const half m = il ? ( -8.h * 16.h) : -8.h;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = il ? 0xF000 : 0x0F00;

     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0))      + m) * d;
+        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) + m) * d;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
     }
 }
@@ -1640,11 +1707,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     const ushort mask1 = il ? 0xF000 : 0x0F00;

     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0))      * d) + m;
+        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) * d) + m;
         reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
     }
 }

+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const half d = xb->d;
+
+    for (int i=0;i<16;i++) {
+        reg[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const half d = xb->d;
@@ -1947,9 +2024,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);

-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
@@ -1960,9 +2038,10 @@ typedef void (mat_mm_t)(device const uchar *, device const float *, device float
                         constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
                         constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);

-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
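
For reference (not part of the patch): dequantize_q8_0 above recovers each value as qs[i] * d, where d is one fp16 delta stored per 32-value block. Below is a minimal host-side C sketch of the same math, assuming only the block layout shown in the patch; f16_to_f32, block_q8_0_ref, and dequantize_row_q8_0_ref are illustrative stand-ins, not ggml API.

#include <stdint.h>
#include <string.h>

#define QK8_0 32

typedef struct {
    uint16_t d;          // per-block delta, stored as IEEE fp16 bits
    int8_t   qs[QK8_0];  // 32 signed 8-bit quants
} block_q8_0_ref;

// stand-in fp16 -> fp32 conversion; flushes fp16 subnormals to zero,
// which is harmless for illustrating the scaling here
static float f16_to_f32(uint16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000) << 16;
    const uint32_t exp  = (h >> 10) & 0x1F;
    const uint32_t mant = h & 0x3FF;
    uint32_t bits;
    if (exp == 0) {
        bits = sign;                              // zero (subnormals flushed)
    } else if (exp == 31) {
        bits = sign | 0x7F800000u | (mant << 13); // inf/NaN
    } else {
        bits = sign | ((exp + 112) << 23) | (mant << 13); // rebias 15 -> 127
    }
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}

// same per-element math as the Metal dequantize_q8_0: y[i] = qs[i] * d
static void dequantize_row_q8_0_ref(const block_q8_0_ref * x, float * y, int nb) {
    for (int ib = 0; ib < nb; ++ib) {
        const float d = f16_to_f32(x[ib].d);
        for (int i = 0; i < QK8_0; ++i) {
            y[ib*QK8_0 + i] = x[ib].qs[i] * d;
        }
    }
}

The GPU kernels split this same loop across SIMD lanes: the mat-vec kernel gives each thread 16 of a block's 32 quants, which is why il selects a 16-value half both there and in dequantize_q8_0.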