metal : add gqa8 kernel to allow llama-2-70B on metal (#2459)

* Added gqa8 kernel to allow llama-2-70B on metal

* Update ggml-metal.m

Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>

* Extend kernel_mul_mat_f16_f32 to handle gqa broadcast

* Added ne03==ne13 assertion

---------

Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
This commit is contained in:
Matteo Boschini 2023-08-01 09:43:12 +02:00 committed by GitHub
parent 49e7cb5bb1
commit 1873ff586b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 17 deletions

View file

@ -718,7 +718,8 @@ void ggml_metal_graph_compute(
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne00 == ne10); GGML_ASSERT(ne00 == ne10);
GGML_ASSERT(ne02 == ne12); // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
GGML_ASSERT(ne03 == ne13);
if (ggml_is_contiguous(src0) && if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) && ggml_is_contiguous(src1) &&
@ -746,11 +747,11 @@ void ggml_metal_graph_compute(
initWithDevice:ctx->device transposeLeft:false transposeRight:true initWithDevice:ctx->device transposeLeft:false transposeRight:true
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0]; resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
// we need to do ne02 multiplications // we need to do ne12 multiplications
// TODO: is there a way to do this in parallel - currently very slow .. // TODO: is there a way to do this in parallel - currently very slow ..
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
for (int64_t i02 = 0; i02 < ne02; ++i02) { for (int64_t i02 = 0; i02 < ne12; ++i02) {
size_t offs_src0_cur = offs_src0 + i02*nb02; size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
size_t offs_src1_cur = offs_src1 + i02*nb12; size_t offs_src1_cur = offs_src1 + i02*nb12;
size_t offs_dst_cur = offs_dst + i02*nb2; size_t offs_dst_cur = offs_dst + i02*nb2;
@ -772,8 +773,6 @@ void ggml_metal_graph_compute(
switch (src0t) { switch (src0t) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
GGML_ASSERT(ne02 == ne12);
nth0 = 64; nth0 = 64;
nth1 = 1; nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@ -853,16 +852,18 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6]; [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9]; [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {

View file

@ -509,11 +509,13 @@ kernel void kernel_mul_mat_f16_f32(
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01, constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00, constant uint64_t & nb00,
constant uint64_t & nb01, constant uint64_t & nb01,
constant uint64_t & nb02, constant uint64_t & nb02,
constant int64_t & ne10, constant int64_t & ne10,
constant int64_t & ne11, constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10, constant uint64_t & nb10,
constant uint64_t & nb11, constant uint64_t & nb11,
constant uint64_t & nb12, constant uint64_t & nb12,
@ -529,7 +531,7 @@ kernel void kernel_mul_mat_f16_f32(
const int64_t r1 = tgpig.y; const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z; const int64_t im = tgpig.z;
device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02); device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
sum[tpitg.x] = 0.0f; sum[tpitg.x] = 0.0f;
@ -552,6 +554,7 @@ kernel void kernel_mul_mat_f16_f32(
} }
} }
kernel void kernel_alibi_f32( kernel void kernel_alibi_f32(
device const float * src0, device const float * src0,
device float * dst, device float * dst,