metal : optimize ggml_mul_mat_id (faster Mixtral PP) (llama/4725)

* ggml : disable fast-math for Metal (cmake build only)

ggml-ci

* metal : fix Metal API debug warnings

* cmake : add -fno-inline for Metal build (llama/4545)

* metal : fix API debug warnings

* metal : fix compile warnings

* metal : use uint64_t for strides

* cmake : rename option to LLAMA_METAL_SHADER_DEBUG

* metal : fix mat-vec Q8_0 kernel for BS > 1

* metal : normalize mat-vec kernel signatures

* cmake : respect LLAMA_QKK_64 option

* metal : fix mat-vec Q4_K kernel for QK_K == 64

* metal : optimizing ggml_mul_mat_id (wip)

* metal : minor fix

* metal : opt mul_mm_id
pull/1725/head
Georgi Gerganov 2024-01-02 21:07:47 +02:00
parent 1e5544b39b
commit f38c057503
2 changed files with 190 additions and 46 deletions

View File

@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute(
}
};
if (ggml_is_quantized(src0t)) {
GGML_ASSERT(ne00 >= nth0*nth1);
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute(
// TODO: make this more general
GGML_ASSERT(n_as <= 8);
// max size of the src1ids array in the kernel stack
GGML_ASSERT(ne11 <= 512);
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
const int64_t ne20 = src2 ? src2->ne[0] : 0;
@ -1732,9 +1739,6 @@ void ggml_metal_graph_compute(
GGML_ASSERT(!ggml_is_transposed(src2));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(ne20 % 32 == 0);
// !!!!!!!!! TODO: this assert is probably required but not sure!
//GGML_ASSERT(ne20 >= 64);
GGML_ASSERT(src1t == GGML_TYPE_F32);
const uint r2 = ne12/ne22;
@ -1742,22 +1746,22 @@ void ggml_metal_graph_compute(
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
int ne11_mm_min = 1;
int ne11_mm_min = n_as;
const int idx = ((int32_t *) dst->op_params)[0];
// batch size
GGML_ASSERT(ne01 == ne11);
const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
// !!!
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
// indirect matrix multiplication
// !!!
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne20 % 32 == 0 && ne20 >= 64 &&
ne11 > ne11_mm_min) {
switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute(
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
// TODO: processing one row at a time (ne11 -> 1) is not efficient
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
int nth1 = 1;
@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute(
} break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
GGML_ASSERT(false && "not implemented");
}
};
if (ggml_is_quantized(src2t)) {
GGML_ASSERT(ne20 >= nth0*nth1);
}
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];

View File

@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
//Note: This is a template, but strictly speaking it only applies to
// quantizations where the block size is 32. It also does not
// giard against the number of rows not being divisible by
// guard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32_impl(
@ -3973,6 +3973,131 @@ void kernel_mul_mm_impl(device const uchar * src0,
}
}
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_id_impl(
device const uchar * src0,
device const uchar * src1,
thread short * src1ids,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
int64_t ne1,
constant uint & r2,
constant uint & r3,
threadgroup uchar * shared_memory,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
const uint r0 = tgpig.y;
const uint r1 = tgpig.x;
const uint im = tgpig.z;
if (r1 * BLOCK_SIZE_N >= ne1) return;
// if this block is of 64x32 shape or smaller
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8];
for (int i = 0; i < 8; i++){
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
short il = (tiitg % THREAD_PER_ROW);
const uint i12 = im%ne12;
const uint i13 = im/ne12;
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
+ nb12 * im
+ nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
half4x4 temp_a;
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
}
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
il = (il + 2 < nl) ? il + 2 : il % 2;
x = (il < 2) ? x + (2+nl-1)/nl : x;
y += BLOCK_SIZE_K;
threadgroup_barrier(mem_flags::mem_threadgroup);
// load matrices from threadgroup memory and conduct outer products
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
for (int i = 0; i < 4; i++) {
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
}
simdgroup_barrier(mem_flags::mem_none);
for (int i = 0; i < 2; i++) {
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
}
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
for (int i = 0; i < 8; i++){
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
}
}
}
{
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
for (int i = 0; i < 8; i++) {
simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
if (sgitg == 0) {
for (int i = 0; i < n_rows; i++) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
*(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
}
}
}
}
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0,
device const uchar * src1,
@ -4019,7 +4144,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
kernel void kernel_mul_mm_id(
device const uchar * ids,
device const uchar * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
@ -4048,18 +4173,28 @@ kernel void kernel_mul_mm_id(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
const int64_t bid = tgpig.z/(ne12*ne13);
// expert id
const int32_t id = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
// row indices of src1 for expert id
int64_t _ne1 = 0;
short src1ids[512];
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
src0[id],
src1 + bid*nb11,
(device float *) (dst + bid*nb1),
for (int64_t i1 = 0; i1 < ne1; i1++) {
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
src1ids[_ne1++] = i1;
}
}
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
src0s[id],
src1,
src1ids,
dst,
ne00,
ne02,
nb01,
@ -4069,7 +4204,7 @@ kernel void kernel_mul_mm_id(
nb11,
nb12,
ne0,
ne1,
_ne1,
r2,
r3,
shared_memory,
@ -4158,7 +4293,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
typedef void (mat_mm_id_t)(
device const uchar * ids,
device const uchar * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
@ -4207,7 +4342,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
kernel void kernel_mul_mv_id_f32_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4251,7 +4386,7 @@ kernel void kernel_mul_mv_id_f32_f32(
kernel_mul_mv_f32_f32_impl(
src0[id],
src1 + bid*nb11,
(device float *) (dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4276,7 +4411,7 @@ kernel void kernel_mul_mv_id_f32_f32(
kernel void kernel_mul_mv_id_f16_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4320,7 +4455,7 @@ kernel void kernel_mul_mv_id_f16_f32(
kernel_mul_mv_f16_f32_impl(
src0[id],
src1 + bid*nb11,
(device float *) (dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4345,7 +4480,7 @@ kernel void kernel_mul_mv_id_f16_f32(
kernel void kernel_mul_mv_id_q8_0_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4389,7 +4524,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
kernel_mul_mv_q8_0_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
(device float *) ( dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4408,7 +4543,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
kernel void kernel_mul_mv_id_q4_0_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4452,7 +4587,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
src0[id],
(device const float *) (src1 + bid*nb11),
(device float *) ( dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4471,7 +4606,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
kernel void kernel_mul_mv_id_q4_1_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4515,7 +4650,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
src0[id],
(device const float *) (src1 + bid*nb11),
(device float *) ( dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4534,7 +4669,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
kernel void kernel_mul_mv_id_q5_0_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4578,7 +4713,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
src0[id],
(device const float *) (src1 + bid*nb11),
(device float *) ( dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4597,7 +4732,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
kernel void kernel_mul_mv_id_q5_1_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4641,7 +4776,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
src0[id],
(device const float *) (src1 + bid*nb11),
(device float *) ( dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4660,7 +4795,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
kernel void kernel_mul_mv_id_q2_K_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4704,7 +4839,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
kernel_mul_mv_q2_K_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
(device float *) ( dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4723,7 +4858,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
kernel void kernel_mul_mv_id_q3_K_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4767,7 +4902,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
kernel_mul_mv_q3_K_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
(device float *) ( dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4786,7 +4921,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
kernel void kernel_mul_mv_id_q4_K_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4830,7 +4965,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
kernel_mul_mv_q4_K_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
(device float *) ( dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4849,7 +4984,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
kernel void kernel_mul_mv_id_q5_K_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4893,7 +5028,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
kernel_mul_mv_q5_K_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
(device float *) ( dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,
@ -4912,7 +5047,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
kernel void kernel_mul_mv_id_q6_K_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
device float * dst,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
@ -4956,7 +5091,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
kernel_mul_mv_q6_K_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
(device float *) ( dst + bid*nb1),
dst + bid*ne0,
ne00,
ne01,
ne02,