metal : support bcast add & dup & cont op (#2323)

This commit is contained in:
Jiahao Li 2023-07-23 19:00:37 +08:00 committed by GitHub
parent d2a43664f9
commit 83a00ce69b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 1 deletions

View file

@ -42,6 +42,7 @@ struct ggml_metal_context {
id<MTLComputePipelineState> pipeline_##name id<MTLComputePipelineState> pipeline_##name
GGML_METAL_DECL_KERNEL(add); GGML_METAL_DECL_KERNEL(add);
GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
GGML_METAL_DECL_KERNEL(mul); GGML_METAL_DECL_KERNEL(mul);
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
GGML_METAL_DECL_KERNEL(scale); GGML_METAL_DECL_KERNEL(scale);
@ -157,6 +158,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
GGML_METAL_ADD_KERNEL(add); GGML_METAL_ADD_KERNEL(add);
GGML_METAL_ADD_KERNEL(add_row);
GGML_METAL_ADD_KERNEL(mul); GGML_METAL_ADD_KERNEL(mul);
GGML_METAL_ADD_KERNEL(mul_row); GGML_METAL_ADD_KERNEL(mul_row);
GGML_METAL_ADD_KERNEL(scale); GGML_METAL_ADD_KERNEL(scale);
@ -464,10 +466,16 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder]; encoder = [command_buffer computeCommandEncoder];
} }
[encoder setComputePipelineState:ctx->pipeline_add]; if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_add_row];
} else {
[encoder setComputePipelineState:ctx->pipeline_add];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
const int64_t n = ggml_nelements(dst); const int64_t n = ggml_nelements(dst);
@ -919,7 +927,9 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_DUP:
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_CONT:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder]; encoder = [command_buffer computeCommandEncoder];

View file

@ -67,6 +67,17 @@ kernel void kernel_add(
dst[tpig] = src0[tpig] + src1[tpig]; dst[tpig] = src0[tpig] + src1[tpig];
} }
// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_add_row(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % ne00];
}
kernel void kernel_mul( kernel void kernel_mul(
device const float * src0, device const float * src0,
device const float * src1, device const float * src1,