ggml : sync latest llama.cpp (view_src + alloc improvements) (#1247)

* ggml : sync latest llama.cpp (view_src + alloc improvements)

* ggml : fix build
pull/1250/head
Georgi Gerganov 2023-09-05 20:57:27 +03:00 committed by GitHub
parent ba3c333611
commit c3f319d7c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 865 additions and 529 deletions

View File

@ -81,12 +81,29 @@
#if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
return reinterpret_cast<const int&>(c);
#else
int8x4_t c;
int16_t tmp;
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp = va[i] - vb[i];
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
c[i] = tmp;
}
return reinterpret_cast<int&>(c);
#endif // __has_builtin(__builtin_elementwise_sub_sat)
}
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
@ -447,58 +464,91 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
}
template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const float eps = 1e-5f;
float mean = 0.0f;
float var = 0.0f;
float2 mean_var = make_float2(0.f, 0.f);
for (int col = tid; col < ncols; col += WARP_SIZE) {
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
mean += xi;
var += xi * xi;
mean_var.x += xi;
mean_var.y += xi * xi;
}
// sum up partial sums
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
var += __shfl_xor_sync(0xffffffff, var, mask, 32);
mean_var = warp_reduce_sum(mean_var);
if (block_size > WARP_SIZE) {
__shared__ float2 s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
__syncthreads();
mean_var = s_sum[lane_id];
mean_var = warp_reduce_sum(mean_var);
}
mean /= ncols;
var = var / ncols - mean * mean;
const float inv_var = rsqrtf(var + eps);
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
const float inv_std = rsqrtf(var + eps);
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
}
}
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
}
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += WARP_SIZE) {
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
tmp += xi * xi;
}
// sum up partial sums
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += WARP_SIZE) {
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * x[row*ncols + col];
}
}
@ -4186,14 +4236,24 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
} else {
const dim3 block_dims(1024, 1, 1);
norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}
}
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
}
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
@ -5721,7 +5781,6 @@ inline void ggml_cuda_op_alibi(
(void) src1;
(void) src0_ddq_i;
(void) src1_ddf_i;
(void) i02;
(void) i1;
}

View File

@ -11,6 +11,7 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
// TODO: temporary - reuse llama.cpp logging
#ifdef GGML_METAL_NDEBUG
#define metal_printf(...)
#else
@ -75,6 +76,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@ -113,12 +115,26 @@ static NSString * const msl_library_source = @"see metal.metal";
@end
struct ggml_metal_context * ggml_metal_init(int n_cb) {
fprintf(stderr, "%s: allocating\n", __func__);
metal_printf("%s: allocating\n", __func__);
// Show all the Metal device instances in the system
NSArray * devices = MTLCopyAllDevices();
id <MTLDevice> device;
NSString * s;
for (device in devices) {
s = [device name];
metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
}
// Pick and show default Metal device
device = MTLCreateSystemDefaultDevice();
s = [device name];
metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
// Configure context
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
ctx->device = device;
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0;
ctx->concur_list_len = 0;
@ -132,7 +148,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
}
@ -146,11 +162,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
@ -162,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
#endif
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
}
@ -174,11 +190,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
#define GGML_METAL_ADD_KERNEL(name) \
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
(int) ctx->pipeline_##name.threadExecutionWidth); \
if (error) { \
fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
return NULL; \
}
@ -204,6 +220,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@ -230,19 +247,19 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
#undef GGML_METAL_ADD_KERNEL
}
fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
if (ctx->device.maxTransferRate != 0) {
fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
} else {
fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
}
return ctx;
}
void ggml_metal_free(struct ggml_metal_context * ctx) {
fprintf(stderr, "%s: deallocating\n", __func__);
metal_printf("%s: deallocating\n", __func__);
#define GGML_METAL_DEL_KERNEL(name) \
[ctx->function_##name release]; \
[ctx->pipeline_##name release];
@ -269,6 +286,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@ -311,7 +329,7 @@ void * ggml_metal_host_malloc(size_t n) {
void * data = NULL;
const int result = posix_memalign((void **) &data, getpagesize(), n);
if (result != 0) {
fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
metal_printf("%s: error: posix_memalign failed\n", __func__);
return NULL;
}
@ -339,7 +357,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
// Metal buffer based on the host memory pointer
//
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
//fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
//metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
const int64_t tsize = ggml_nbytes(t);
@ -350,13 +368,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
*offs = (size_t) ioffs;
//fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
//metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
return ctx->buffers[i].metal;
}
}
fprintf(stderr, "%s: error: buffer is nil\n", __func__);
metal_printf("%s: error: buffer is nil\n", __func__);
return nil;
}
@ -368,7 +386,7 @@ bool ggml_metal_add_buffer(
size_t size,
size_t max_size) {
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
fprintf(stderr, "%s: too many buffers\n", __func__);
metal_printf("%s: too many buffers\n", __func__);
return false;
}
@ -378,7 +396,7 @@ bool ggml_metal_add_buffer(
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
return false;
}
}
@ -399,11 +417,11 @@ bool ggml_metal_add_buffer(
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) {
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
return false;
}
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
++ctx->n_buffers;
} else {
@ -423,27 +441,27 @@ bool ggml_metal_add_buffer(
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) {
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
return false;
}
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
if (i + size_step < size) {
fprintf(stderr, "\n");
metal_printf("\n");
}
++ctx->n_buffers;
}
}
fprintf(stderr, ", (%8.2f / %8.2f)",
metal_printf(", (%8.2f / %8.2f)",
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
} else {
fprintf(stderr, "\n");
metal_printf("\n");
}
}
@ -453,8 +471,6 @@ bool ggml_metal_add_buffer(
void ggml_metal_set_tensor(
struct ggml_metal_context * ctx,
struct ggml_tensor * t) {
metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
size_t offs;
id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
@ -464,8 +480,6 @@ void ggml_metal_set_tensor(
void ggml_metal_get_tensor(
struct ggml_metal_context * ctx,
struct ggml_tensor * t) {
metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
size_t offs;
id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
@ -560,15 +574,13 @@ void ggml_metal_graph_find_concurrency(
}
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
}
}
void ggml_metal_graph_compute(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
metal_printf("%s: evaluating graph\n", __func__);
@autoreleasepool {
// if there is ctx->concur_list, dispatch concurrently
@ -616,7 +628,7 @@ void ggml_metal_graph_compute(
continue;
}
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
//metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@ -685,6 +697,12 @@ void ggml_metal_graph_compute(
} break;
case GGML_OP_ADD:
{
GGML_ASSERT(ggml_is_contiguous(src0));
// utilize float4
GGML_ASSERT(ne00 % 4 == 0);
const int64_t nb = ne00/4;
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_add_row];
@ -694,14 +712,20 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
const int64_t n = ggml_nelements(dst);
const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_MUL:
{
GGML_ASSERT(ggml_is_contiguous(src0));
// utilize float4
GGML_ASSERT(ne00 % 4 == 0);
const int64_t nb = ne00/4;
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_mul_row];
@ -711,9 +735,9 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
const int64_t n = ggml_nelements(dst);
const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@ -764,7 +788,7 @@ void ggml_metal_graph_compute(
} break;
default:
{
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
}
} break;
@ -845,9 +869,13 @@ void ggml_metal_graph_compute(
switch (src0t) {
case GGML_TYPE_F16:
{
nth0 = 64;
nth0 = 32;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
}
} break;
case GGML_TYPE_Q4_0:
{
@ -899,8 +927,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 2;
nth1 = 32;
nth0 = 4; //1;
nth1 = 8; //32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
@ -923,7 +951,7 @@ void ggml_metal_graph_compute(
} break;
default:
{
fprintf(stderr, "Asserting on type %d\n",(int)src0t);
metal_printf("Asserting on type %d\n",(int)src0t);
GGML_ASSERT(false && "not implemented");
}
};
@ -948,9 +976,12 @@ void ggml_metal_graph_compute(
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@ -964,8 +995,8 @@ void ggml_metal_graph_compute(
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
int64_t ny = (ne11 + 3)/4;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
} break;
@ -1161,7 +1192,7 @@ void ggml_metal_graph_compute(
} break;
default:
{
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
}
}
@ -1186,7 +1217,7 @@ void ggml_metal_graph_compute(
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
if (status != MTLCommandBufferStatusCompleted) {
fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
GGML_ASSERT(false);
}
}

View File

@ -25,9 +25,9 @@ typedef struct {
} block_q8_0;
kernel void kernel_add(
device const float * src0,
device const float * src1,
device float * dst,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig];
}
@ -35,18 +35,18 @@ kernel void kernel_add(
// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_add_row(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % ne00];
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
kernel void kernel_mul(
device const float * src0,
device const float * src1,
device float * dst,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig];
}
@ -54,12 +54,12 @@ kernel void kernel_mul(
// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_mul_row(
device const float * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
constant int64_t & nb,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig % ne00];
dst[tpig] = src0[tpig] * src1[tpig % nb];
}
kernel void kernel_scale(
@ -133,19 +133,24 @@ kernel void kernel_soft_max(
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// broadcast
if (tpitg[0] == 0) {
buf[0] = buf[0];
}
//// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
// the loop, and when that is done, buf[0] has the correct (synchronized) value
//if (tpitg[0] == 0) {
// buf[0] = buf[0];
//}
threadgroup_barrier(mem_flags::mem_threadgroup);
//threadgroup_barrier(mem_flags::mem_threadgroup);
const float max = buf[0];
// parallel sum
buf[tpitg[0]] = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
buf[tpitg[0]] += exp(psrc0[i00] - max);
const float exp_psrc0 = exp(psrc0[i00] - max);
buf[tpitg[0]] += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// whish to compute it twice.
pdst[i00] = exp_psrc0;
}
// reduce
@ -157,17 +162,18 @@ kernel void kernel_soft_max(
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// broadcast
if (tpitg[0] == 0) {
buf[0] = buf[0];
}
// broadcast - not needed, see above
//// broadcast
//if (tpitg[0] == 0) {
// buf[0] = buf[0];
//}
threadgroup_barrier(mem_flags::mem_threadgroup);
//threadgroup_barrier(mem_flags::mem_threadgroup);
const float sum = buf[0];
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
pdst[i00] = exp(psrc0[i00] - max) / sum;
pdst[i00] /= sum;
}
}
@ -214,25 +220,27 @@ kernel void kernel_norm(
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// broadcast
if (tpitg == 0) {
sum[0] /= ne00;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
//// broadcast
//if (tpitg == 0) {
// sum[0] /= ne00;
//}
//threadgroup_barrier(mem_flags::mem_threadgroup);
const float mean = sum[0];
// recenter
// recenter and VARIANCE
device float * y = dst + tgpig*ne00;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
y[i00] = x[i00] - mean;
}
// VARIANCE
// parallel sum
sum[tpitg] = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
y[i00] = x[i00] - mean;
sum[tpitg] += y[i00] * y[i00];
}
//// VARIANCE
//// parallel sum
//sum[tpitg] = 0.0f;
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
// sum[tpitg] += y[i00] * y[i00];
//}
// reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg/2; i > 0; i /= 2) {
@ -241,11 +249,11 @@ kernel void kernel_norm(
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// broadcast
if (tpitg == 0) {
sum[0] /= ne00;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
//// broadcast
//if (tpitg == 0) {
// sum[0] /= ne00;
//}
//threadgroup_barrier(mem_flags::mem_threadgroup);
const float variance = sum[0];
const float scale = 1.0f/sqrt(variance + eps);
@ -435,6 +443,8 @@ kernel void kernel_mul_mat_q4_1_f32(
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
#define NB_Q8_0 8
kernel void kernel_mul_mat_q8_0_f32(
device const void * src0,
device const float * src1,
@ -463,30 +473,30 @@ kernel void kernel_mul_mat_q8_0_f32(
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[16];
float yl[NB_Q8_0];
float sumf[nr]={0.f};
const int ix = tiisg/2;
const int il = tiisg%2;
const int ix = tiisg/4;
const int il = tiisg%4;
device const float * yb = y + ix * QK8_0 + 16*il;
device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += nw/2) {
for (int i = 0; i < 16; ++i) {
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
for (int ib = ix; ib < nb; ib += nw/4) {
for (int i = 0; i < NB_Q8_0; ++i) {
yl[i] = yb[i];
}
for (int row = 0; row < nr; row++) {
device const int8_t * qs = x[ib+row*nb].qs + 16*il;
device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
float sumq = 0.f;
for (int iq = 0; iq < 16; ++iq) {
for (int iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq];
}
sumf[row] += sumq*x[ib+row*nb].d;
}
yb += QK8_0 * 16;
yb += NB_Q8_0 * nw;
}
for (int row = 0; row < nr; ++row) {
@ -497,6 +507,60 @@ kernel void kernel_mul_mat_q8_0_f32(
}
}
kernel void kernel_mul_mat_f16_f32_1row(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
float sumf = 0;
if (ne00 < 128) {
for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
} else {
device const half4 * x4 = (device const half4 *) x;
device const float4 * y4 = (device const float4 *) y;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
#define N_F16_F32 4
kernel void kernel_mul_mat_f16_f32(
device const char * src0,
device const char * src1,
@ -515,37 +579,58 @@ kernel void kernel_mul_mat_f16_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
threadgroup float * sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpig[[thread_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 tptg[[threads_per_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int64_t rb = tgpig.y*N_F16_F32;
const int64_t im = tgpig.z;
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
sum[tpitg.x] = 0.0f;
if (ne00 < 128) {
for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
for (int i = tpitg.x; i < ne00; i += tptg.x) {
sum[tpitg.x] += (float) x[i] * (float) y[i];
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
// accumulate the sum from all threads in the threadgroup
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = tptg.x/2; i > 0; i /= 2) {
if (tpitg.x < i) {
sum[tpitg.x] += sum[tpitg.x + i];
float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
} else {
device const half4 * x4 = (device const half4 *)x;
for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
device const float4 * y4 = (device const float4 *) y;
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (tpitg.x == 0) {
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
}
}
kernel void kernel_alibi_f32(
@ -1244,7 +1329,8 @@ kernel void kernel_mul_mat_q4_K_f32(
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int first_row = r0 * N_DST;
const int ib_row = first_row * nb;
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;

View File

@ -1334,7 +1334,7 @@ void ggml_cl_free_data(const struct ggml_tensor* tensor) {
return;
}
cl_mem mem = (cl_mem)tensor->data;
cl_mem mem = (cl_mem)tensor->extra;
clReleaseMemObject(mem);
}
@ -1393,7 +1393,7 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
size_t d_size;
cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0
cl_mem d_Y = (cl_mem) src1->data; // src1 is already on device, broadcasted.
cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst
@ -1491,9 +1491,9 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
size_t d_size;
cl_mem d_X;
if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
d_X = (cl_mem) src0->data;
d_X = (cl_mem) src0->extra;
} else {
d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
d_X = ggml_cl_pool_malloc(sizeof(float) * x_ne, &x_size);
}
cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
@ -1567,7 +1567,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
size_t d_size;
cl_mem d_X;
if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
d_X = (cl_mem) src0->data;
d_X = (cl_mem) src0->extra;
} else {
d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
}
@ -1697,7 +1697,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
events.emplace_back();
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
} else if (src0->backend == GGML_BACKEND_GPU) {
d_Q = (cl_mem) src0->data;
d_Q = (cl_mem) src0->extra;
} else {
GGML_ASSERT(false);
}
@ -1860,6 +1860,6 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
CL_CHECK(clFinish(queue));
tensor->data = dst;
tensor->extra = dst;
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
}

868
ggml.c

File diff suppressed because it is too large Load Diff

34
ggml.h
View File

@ -479,6 +479,9 @@ extern "C" {
int64_t perf_cycles;
int64_t perf_time_us;
struct ggml_tensor * view_src;
size_t view_offs;
void * data;
char name[GGML_MAX_NAME];
@ -661,7 +664,7 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
@ -952,11 +955,11 @@ extern "C" {
// a - x
// b - dy
// TODO: update with configurable eps
GGML_API struct ggml_tensor * ggml_rms_norm_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
struct ggml_tensor * b,
float eps);
// A: n columns, m rows
// B: n columns, p rows (i.e. we transpose it internally)
@ -1612,7 +1615,8 @@ extern "C" {
struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
@ -1677,6 +1681,8 @@ extern "C" {
GGML_LINESEARCH_INVALID_PARAMETERS,
};
typedef void (*ggml_opt_callback)(void * data, float * sched);
// optimization parameters
//
// see ggml.c (ggml_opt_default_params) for default values
@ -1712,12 +1718,14 @@ extern "C" {
float sched; // schedule multiplier (fixed, decay or warmup)
float decay; // weight decay for AdamW, use 0.0f to disable
int decay_min_ndim; // minimum number of tensor dimension to apply weight decay
float alpha; // learning rate
float beta1;
float beta2;
float eps; // epsilon for numerical stability
float eps_f; // epsilon for convergence test
float eps_g; // epsilon for convergence test
float gclip; // gradient clipping
} adam;
// LBFGS parameters
@ -1745,14 +1753,12 @@ extern "C" {
bool just_initialized;
float loss_before;
float loss_after;
struct {
struct ggml_tensor * x; // view of the parameters
struct ggml_tensor * g1; // gradient
struct ggml_tensor * g2; // gradient squared
struct ggml_tensor * m; // first moment
struct ggml_tensor * v; // second moment
struct ggml_tensor * mh; // first moment hat
struct ggml_tensor * vh; // second moment hat
struct ggml_tensor * pf; // past function values
float fx_best;
float fx_prev;
@ -1789,10 +1795,10 @@ extern "C" {
// initialize optimizer context
GGML_API void ggml_opt_init(
struct ggml_context * ctx,
struct ggml_context * ctx,
struct ggml_opt_context * opt,
struct ggml_opt_params params,
int64_t nx);
struct ggml_opt_params params,
int64_t nx);
// continue optimizing the function defined by the tensor f
GGML_API enum ggml_opt_result ggml_opt_resume(
@ -1806,7 +1812,9 @@ extern "C" {
struct ggml_opt_context * opt,
struct ggml_tensor * f,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb);
struct ggml_cgraph * gb,
ggml_opt_callback callback,
void * callback_data);
//
// quantization