From 4d98d9a65665eee3838cef936641f640e3f5b649 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 13 Dec 2023 21:54:54 +0200 Subject: [PATCH] sync : ggml (SD ops, tests, kernels) (#4444) * sync : ggml (SD ops, tests, kernels) ggml-ci * cuda : restore im2col ggml-ci * metal : fix accuracy of dequantization kernels ggml-ci * cuda : restore correct im2col ggml-ci * metal : try to fix moe test by reducing expert size ggml-ci * cuda : fix bin bcast when src1 and dst have different types ggml-ci --------- Co-authored-by: slaren --- ggml-cuda.cu | 481 +++++++++++++++++++++++++++++++++++-- ggml-metal.m | 265 +++++++++++++++++++- ggml-metal.metal | 296 +++++++++++++++++++---- ggml.c | 183 +++++++++++--- ggml.h | 20 +- tests/test-backend-ops.cpp | 219 ++++++++++++++++- 6 files changed, 1334 insertions(+), 130 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 9e1acd3f1..019648bdd 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -439,6 +439,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_TANH_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 @@ -451,6 +452,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_QUANTIZE_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 #define CUDA_GET_ROWS_BLOCK_SIZE 256 +#define CUDA_UPSCALE_BLOCK_SIZE 256 +#define CUDA_CONCAT_BLOCK_SIZE 256 +#define CUDA_PAD_BLOCK_SIZE 256 +#define CUDA_ACC_BLOCK_SIZE 256 +#define CUDA_IM2COL_BLOCK_SIZE 256 // dmmv = dequantize_mul_mat_vec #ifndef GGML_CUDA_DMMV_X @@ -612,6 +618,24 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } +static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne, + const int ne10, const int ne11, const int ne12, + const int nb1, const int nb2, int offset) { + const int i = blockDim.x * blockIdx.x + threadIdx.x; + if (i >= ne) { + return; + } + int src1_idx = i - offset; + int oz = src1_idx / nb2; + int oy = (src1_idx - (oz * nb2)) / nb1; + int ox = src1_idx % nb1; + if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { + dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; + } else { + dst[i] = x[i]; + } +} + static __global__ void gelu_f32(const float * x, float * dst, const int k) { const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -634,6 +658,23 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { dst[i] = x[i] / (1.0f + expf(-x[i])); } +static __global__ void gelu_quick_f32(const float *x, float *dst, int k) { + const float GELU_QUICK_COEF = -1.702f; + const int i = blockDim.x*blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i]))); +} + +static __global__ void tanh_f32(const float *x, float *dst, int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = tanhf(x[i]); +} + static __global__ void relu_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -643,6 +684,14 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) { dst[i] = fmaxf(x[i], 0); } +static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope; +} + static __global__ void sqr_f32(const float * x, float * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -688,6 +737,132 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c } } +static __global__ void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + // operation + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + if (blockIdx.z < ne02) { // src0 + int offset_src = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + + blockIdx.y * ne0 + + (blockIdx.z - ne02) * ne0 * gridDim.y; + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor) { + int ne0 = ne00 * scale_factor; + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + // operation + int i00 = nidx / scale_factor; + int i01 = blockIdx.y / scale_factor; + int offset_src = + i00 + + i01 * ne00 + + blockIdx.z * nb02; + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + dst[offset_dst] = x[offset_src]; +} + +static __global__ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + // operation + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) { + int offset_src = + nidx + + blockIdx.y * ne00 + + blockIdx.z * ne00 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + dst[offset_dst] = 0.0f; + } +} + +template +static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) { + int start = blockIdx.x * group_size; + int end = start + group_size; + + start += threadIdx.x; + + if (end >= ne_elements) { + end = ne_elements; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += block_size) { + tmp += x[j]; + } + + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + float mean = tmp / group_size; + tmp = 0.0f; + + for (int j = start; j < end; j += block_size) { + float xi = x[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + float variance = tmp / group_size; + float scale = rsqrtf(variance + eps); + for (int j = start; j < end; j += block_size) { + dst[j] *= scale; + } +} + template static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; @@ -5071,19 +5246,30 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min, static __global__ void im2col_f32_f16( const float * x, half * dst, - int ofs0, int ofs1, int IW, int IH, int CHW, + int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW, int s0, int s1, int p0, int p1, int d0, int d1) { - const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0; - const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1; + const int i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= pelements) { + return; + } + + const int ksize = OW * (KH > 1 ? KW : 1); + const int kx = i / ksize; + const int kd = kx * ksize; + const int ky = (i - kd) / OW; + const int ix = i % OW; + + const int iiw = ix * s0 + kx * d0 - p0; + const int iih = blockIdx.y * s1 + ky * d1 - p1; const int offset_dst = - (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW + - (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z); + (blockIdx.y * OW + ix) * CHW + + (blockIdx.z * (KW * KH) + ky * KW + kx); if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { dst[offset_dst] = __float2half(0.0f); } else { - const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1; + const int offset_src = blockIdx.z * offset_delta; dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]); } } @@ -5220,10 +5406,10 @@ struct bin_bcast_cuda { size_t nb12 = cnb1[2]; size_t nb13 = cnb1[3]; - size_t s0 = nb0 / sizeof(src1_t); - size_t s1 = nb1 / sizeof(src1_t); - size_t s2 = nb2 / sizeof(src1_t); - size_t s3 = nb3 / sizeof(src1_t); + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); size_t s10 = nb10 / sizeof(src1_t); size_t s11 = nb11 / sizeof(src1_t); @@ -5269,6 +5455,13 @@ struct bin_bcast_cuda { } }; +static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements, + const int ne10, const int ne11, const int ne12, + const int nb1, const int nb2, const int offset, cudaStream_t stream) { + int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; + acc_f32<<>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset); +} + static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; gelu_f32<<>>(x, dst, k); @@ -5279,11 +5472,26 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ silu_f32<<>>(x, dst, k); } +static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; + gelu_quick_f32<<>>(x, dst, k); +} + +static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; + tanh_f32<<>>(x, dst, k); +} + static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; relu_f32<<>>(x, dst, k); } +static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + leaky_relu_f32<<>>(x, dst, k, negative_slope); +} + static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE; sqr_f32<<>>(x, dst, k); @@ -5300,6 +5508,38 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i } } +static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) { + static const float eps = 1e-6f; + if (group_size < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + group_norm_f32<<>>(x, dst, group_size, ne_elements, eps); + } else { + const dim3 block_dims(1024, 1, 1); + group_norm_f32<1024><<>>(x, dst, group_size, ne_elements, eps); + } +} + +static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) { + int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; + dim3 gridDim(num_blocks, ne1, ne2); + concat_f32<<>>(x, y, dst, ne0, ne02); +} + +static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) { + int ne0 = (ne00 * scale_factor); + int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; + dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02); + upscale_f32<<>>(x, dst, ne00, ne00 * ne01, scale_factor); +} + +static void pad_f32_cuda(const float * x, float * dst, + const int ne00, const int ne01, const int ne02, + const int ne0, const int ne1, const int ne2, cudaStream_t stream) { + int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; + dim3 gridDim(num_blocks, ne1, ne2); + pad_f32<<>>(x, dst, ne0, ne00, ne01, ne02); +} + static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { @@ -6262,13 +6502,14 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con soft_max_f32<<>>(x, y, dst, ncols_x, nrows_y, scale); } -static void im2col_f32_f16_cuda(const float * x, half * dst, - int OH, int IW, int IH, int OW, int IC, - int KH, int KW, int N, int ofs0, int ofs1, - int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) { - dim3 block_nums(IC, OH, OW); - dim3 block_dims(N, KH, KW); - im2col_f32_f16<<>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1); +static void im2col_f32_f16_cuda(const float* x, half* dst, + int IW, int IH, int OW, int OH, int KW, int KH, int IC, + int offset_delta, + int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { + const int parallel_elements = OW * KW * KH; + const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; + dim3 block_nums(num_blocks, OH, IC); + im2col_f32_f16<<>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1); } // buffer pool for cuda @@ -6615,6 +6856,25 @@ inline void ggml_cuda_op_add( ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } +inline void ggml_cuda_op_acc( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream); + + (void) dst; +} + inline void ggml_cuda_op_mul( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -6657,6 +6917,34 @@ inline void ggml_cuda_op_silu( (void) src1_dd; } +inline void ggml_cuda_op_gelu_quick( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_tanh( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + tanh_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + inline void ggml_cuda_op_relu( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -6671,6 +6959,23 @@ inline void ggml_cuda_op_relu( (void) src1_dd; } +inline void ggml_cuda_op_leaky_relu( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + + leaky_relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + inline void ggml_cuda_op_sqr( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -6705,6 +7010,71 @@ inline void ggml_cuda_op_norm( (void) src1_dd; } + +inline void ggml_cuda_op_group_norm( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int num_groups = dst->op_params[0]; + int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_concat( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream); + } + + (void) src1; + (void) dst; +} + +inline void ggml_cuda_op_upscale( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + + const int scale_factor = dst->op_params[0]; + + upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream); + + (void) src1; + (void) dst; +} + +inline void ggml_cuda_op_pad( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + + pad_f32_cuda(src0_dd, dst_dd, + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], main_stream); + + (void) src1; + (void) dst; +} + inline void ggml_cuda_op_rms_norm( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -7219,7 +7589,6 @@ inline void ggml_cuda_op_im2col( const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; - const int64_t N = src1->ne[is_2D ? 3 : 2]; const int64_t IC = src1->ne[is_2D ? 2 : 1]; const int64_t IH = is_2D ? src1->ne[1] : 1; const int64_t IW = src1->ne[0]; @@ -7230,17 +7599,15 @@ inline void ggml_cuda_op_im2col( const int64_t OH = is_2D ? dst->ne[2] : 1; const int64_t OW = dst->ne[1]; - const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 - const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, - OH, IW, IH, OW, IC, KH, KW, N, - ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream); + im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); (void) src0; (void) src0_dd; } + inline void ggml_cuda_op_sum_rows( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -7789,6 +8156,10 @@ static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, gg ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add); } +static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc); +} + static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul); } @@ -7805,10 +8176,22 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); } +static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick); +} + +static void ggml_cuda_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_tanh); +} + static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu); } +static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu); +} + static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr); } @@ -7817,6 +8200,22 @@ static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, g ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); } +static void ggml_cuda_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_group_norm); +} + +static void ggml_cuda_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_concat); +} + +static void ggml_cuda_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_upscale); +} + +static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad); +} + static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm); } @@ -8809,6 +9208,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_ADD: func = ggml_cuda_add; break; + case GGML_OP_ACC: + func = ggml_cuda_acc; + break; case GGML_OP_MUL: func = ggml_cuda_mul; break; @@ -8823,6 +9225,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_UNARY_OP_SILU: func = ggml_cuda_silu; break; + case GGML_UNARY_OP_GELU_QUICK: + func = ggml_cuda_gelu_quick; + break; + case GGML_UNARY_OP_TANH: + func = ggml_cuda_tanh; + break; case GGML_UNARY_OP_RELU: func = ggml_cuda_relu; break; @@ -8833,6 +9241,21 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_NORM: func = ggml_cuda_norm; break; + case GGML_OP_GROUP_NORM: + func = ggml_cuda_group_norm; + break; + case GGML_OP_CONCAT: + func = ggml_cuda_concat; + break; + case GGML_OP_UPSCALE: + func = ggml_cuda_upscale; + break; + case GGML_OP_PAD: + func = ggml_cuda_pad; + break; + case GGML_OP_LEAKY_RELU: + func = ggml_cuda_leaky_relu; + break; case GGML_OP_RMS_NORM: func = ggml_cuda_rms_norm; break; @@ -8855,9 +9278,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ func = ggml_cuda_sqr; break; case GGML_OP_CLAMP: - if (!any_on_device) { - return false; - } func = ggml_cuda_clamp; break; case GGML_OP_CPY: @@ -8866,6 +9286,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_CONT: func = ggml_cuda_dup; break; + case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: @@ -9285,6 +9706,8 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_TANH: return true; default: return false; @@ -9369,6 +9792,12 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten case GGML_OP_IM2COL: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: + case GGML_OP_ACC: + case GGML_OP_CONCAT: + case GGML_OP_GROUP_NORM: + case GGML_OP_UPSCALE: + case GGML_OP_PAD: + case GGML_OP_LEAKY_RELU: return true; default: return false; diff --git a/ggml-metal.m b/ggml-metal.m index 1dcfa6edd..465679a6b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -66,9 +66,11 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(div_row); GGML_METAL_DECL_KERNEL(scale); GGML_METAL_DECL_KERNEL(scale_4); - GGML_METAL_DECL_KERNEL(silu); + GGML_METAL_DECL_KERNEL(tanh); GGML_METAL_DECL_KERNEL(relu); GGML_METAL_DECL_KERNEL(gelu); + GGML_METAL_DECL_KERNEL(gelu_quick); + GGML_METAL_DECL_KERNEL(silu); GGML_METAL_DECL_KERNEL(soft_max); GGML_METAL_DECL_KERNEL(soft_max_4); GGML_METAL_DECL_KERNEL(diag_mask_inf); @@ -86,6 +88,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_q5_K); GGML_METAL_DECL_KERNEL(get_rows_q6_K); GGML_METAL_DECL_KERNEL(rms_norm); + GGML_METAL_DECL_KERNEL(group_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); GGML_METAL_DECL_KERNEL(mul_mv_f16_f16); @@ -145,8 +148,11 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(rope_f16); GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(im2col_f16); + GGML_METAL_DECL_KERNEL(upscale_f32); + GGML_METAL_DECL_KERNEL(pad_f32); GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc); GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc); + GGML_METAL_DECL_KERNEL(leaky_relu_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f32_q8_0); @@ -334,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(div_row); GGML_METAL_ADD_KERNEL(scale); GGML_METAL_ADD_KERNEL(scale_4); - GGML_METAL_ADD_KERNEL(silu); + GGML_METAL_ADD_KERNEL(tanh); GGML_METAL_ADD_KERNEL(relu); GGML_METAL_ADD_KERNEL(gelu); + GGML_METAL_ADD_KERNEL(gelu_quick); + GGML_METAL_ADD_KERNEL(silu); GGML_METAL_ADD_KERNEL(soft_max); GGML_METAL_ADD_KERNEL(soft_max_4); GGML_METAL_ADD_KERNEL(diag_mask_inf); @@ -354,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(get_rows_q5_K); GGML_METAL_ADD_KERNEL(get_rows_q6_K); GGML_METAL_ADD_KERNEL(rms_norm); + GGML_METAL_ADD_KERNEL(group_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); GGML_METAL_ADD_KERNEL(mul_mv_f16_f16); @@ -415,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(rope_f16); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(im2col_f16); + GGML_METAL_ADD_KERNEL(upscale_f32); + GGML_METAL_ADD_KERNEL(pad_f32); GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc); GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc); + GGML_METAL_ADD_KERNEL(leaky_relu_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f32_q8_0); @@ -450,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(div_row); GGML_METAL_DEL_KERNEL(scale); GGML_METAL_DEL_KERNEL(scale_4); - GGML_METAL_DEL_KERNEL(silu); + GGML_METAL_DEL_KERNEL(tanh); GGML_METAL_DEL_KERNEL(relu); GGML_METAL_DEL_KERNEL(gelu); + GGML_METAL_DEL_KERNEL(gelu_quick); + GGML_METAL_DEL_KERNEL(silu); GGML_METAL_DEL_KERNEL(soft_max); GGML_METAL_DEL_KERNEL(soft_max_4); GGML_METAL_DEL_KERNEL(diag_mask_inf); @@ -470,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_q5_K); GGML_METAL_DEL_KERNEL(get_rows_q6_K); GGML_METAL_DEL_KERNEL(rms_norm); + GGML_METAL_DEL_KERNEL(group_norm); GGML_METAL_DEL_KERNEL(norm); GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); GGML_METAL_DEL_KERNEL(mul_mv_f16_f16); @@ -531,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(rope_f16); GGML_METAL_DEL_KERNEL(alibi_f32); GGML_METAL_DEL_KERNEL(im2col_f16); + GGML_METAL_DEL_KERNEL(upscale_f32); + GGML_METAL_DEL_KERNEL(pad_f32); GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc); GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc); + GGML_METAL_DEL_KERNEL(leaky_relu_f32); GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f32_q8_0); @@ -843,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: return true; default: return false; @@ -853,11 +873,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: - case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_GET_ROWS: + case GGML_OP_PERMUTE: case GGML_OP_CONCAT: case GGML_OP_ADD: + case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_SCALE: @@ -865,11 +885,15 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { case GGML_OP_SUM_ROWS: case GGML_OP_SOFT_MAX: case GGML_OP_RMS_NORM: + case GGML_OP_GROUP_NORM: case GGML_OP_NORM: case GGML_OP_ALIBI: case GGML_OP_ROPE: case GGML_OP_IM2COL: + case GGML_OP_UPSCALE: + case GGML_OP_PAD: case GGML_OP_ARGSORT: + case GGML_OP_LEAKY_RELU: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return true; @@ -902,8 +926,9 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { }; } case GGML_OP_DIAG_MASK_INF: + case GGML_OP_GET_ROWS: { - return op->ne[0] % 4 == 0; + return op->ne[3] == 1; } default: return false; @@ -979,7 +1004,10 @@ void ggml_metal_graph_compute( } break; } - GGML_ASSERT(ggml_metal_supports_op(dst)); + if (!ggml_metal_supports_op(dst)) { + GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); + GGML_ASSERT(!"unsupported op"); + } const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; @@ -1076,6 +1104,8 @@ void ggml_metal_graph_compute( case GGML_OP_MUL: case GGML_OP_DIV: { + const size_t offs = 0; + bool bcast_row = false; int64_t nb = ne00; @@ -1134,7 +1164,8 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; if (bcast_row) { const int64_t n = ggml_nelements(dst)/4; @@ -1146,6 +1177,86 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } } break; + case GGML_OP_ACC: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t pnb1 = ((int32_t *) dst->op_params)[0]; + const size_t pnb2 = ((int32_t *) dst->op_params)[1]; + const size_t pnb3 = ((int32_t *) dst->op_params)[2]; + const size_t offs = ((int32_t *) dst->op_params)[3]; + + const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + const int nth = MIN(1024, ne00); + + [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + [encoder setComputePipelineState:ctx->pipeline_add]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; + [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; + [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; + [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_SCALE: { GGML_ASSERT(ggml_is_contiguous(src0)); @@ -1169,16 +1280,15 @@ void ggml_metal_graph_compute( } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(gf->nodes[i])) { - case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_TANH: { - [encoder setComputePipelineState:ctx->pipeline_silu]; + [encoder setComputePipelineState:ctx->pipeline_tanh]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; const int64_t n = ggml_nelements(dst); - GGML_ASSERT(n % 4 == 0); - [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_RELU: { @@ -1199,6 +1309,28 @@ void ggml_metal_graph_compute( const int64_t n = ggml_nelements(dst); GGML_ASSERT(n % 4 == 0); + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + [encoder setComputePipelineState:ctx->pipeline_gelu_quick]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + GGML_ASSERT(n % 4 == 0); + + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SILU: + { + [encoder setComputePipelineState:ctx->pipeline_silu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + GGML_ASSERT(n % 4 == 0); + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; default: @@ -1837,6 +1969,38 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case GGML_OP_GROUP_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + + //float eps; + //memcpy(&eps, dst->op_params, sizeof(float)); + + const float eps = 1e-6f; // TODO: temporarily hardcoded + + const int32_t n_groups = ((int32_t *) dst->op_params)[0]; + + int nth = 32; // SIMD width + + //while (nth < ne00/4 && nth < 1024) { + // nth *= 2; + //} + + [encoder setComputePipelineState:ctx->pipeline_group_norm]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_NORM: { float eps; @@ -2006,6 +2170,65 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; } break; + case GGML_OP_UPSCALE: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int sf = dst->op_params[0]; + + [encoder setComputePipelineState:ctx->pipeline_upscale_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&sf length:sizeof(sf) atIndex:18]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_PAD: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + [encoder setComputePipelineState:ctx->pipeline_pad_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_ARGSORT: { GGML_ASSERT(src0->type == GGML_TYPE_F32); @@ -2027,6 +2250,22 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)]; } break; + case GGML_OP_LEAKY_RELU: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + float slope; + memcpy(&slope, dst->op_params, sizeof(float)); + + [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: diff --git a/ggml-metal.metal b/ggml-metal.metal index 773fac124..fe0ada445 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -79,6 +79,7 @@ kernel void kernel_add( constant int64_t & nb1, constant int64_t & nb2, constant int64_t & nb3, + constant int64_t & offs, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -90,9 +91,9 @@ kernel void kernel_add( const int64_t i12 = i02 % ne12; const int64_t i11 = i01 % ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { const int i10 = i0 % ne10; @@ -204,7 +205,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(27)]], + constant int64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -213,7 +214,7 @@ kernel void kernel_mul_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(27)]], + constant int64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * src1[tpig % nb]; } @@ -222,7 +223,7 @@ kernel void kernel_div_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(27)]], + constant int64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] / src1[tpig % nb]; } @@ -243,6 +244,47 @@ kernel void kernel_scale_4( dst[tpig] = src0[tpig] * scale; } +kernel void kernel_relu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_tanh( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = precise::tanh(x); +} + +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! + // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + kernel void kernel_silu( device const float4 * src0, device float4 * dst, @@ -251,13 +293,6 @@ kernel void kernel_silu( dst[tpig] = x / (1.0f + exp(-x)); } -kernel void kernel_relu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - kernel void kernel_sqr( device const float * src0, device float * dst, @@ -313,22 +348,6 @@ kernel void kernel_sum_rows( dst_row[0] = row_sum; } -constant float GELU_COEF_A = 0.044715f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -kernel void kernel_gelu( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - kernel void kernel_soft_max( device const float * src0, device const float * src1, @@ -650,6 +669,94 @@ kernel void kernel_rms_norm( } } +kernel void kernel_group_norm( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int32_t & n_groups, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t ne = ne00*ne01*ne02; + const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); + + int start = tgpig * gs; + int end = start + gs; + + start += tpitg; + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += ntg) { + tmp += src0[j]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float mean = tmp / gs; + tmp = 0.0f; + + for (int j = start; j < end; j += ntg) { + float xi = src0[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float variance = tmp / gs; + const float scale = 1.0f/sqrt(variance + eps); + for (int j = start; j < end; j += ntg) { + dst[j] *= scale; + } +} + // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) // il indicates where the q4 quants begin (0 or QK4_0/4) // we assume that the yl's have been multiplied with the appropriate scale factor @@ -1656,6 +1763,97 @@ kernel void kernel_im2col_f16( } } +kernel void kernel_upscale_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int32_t & sf, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1/sf; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = src0_ptr[i0/sf]; + } +} + +kernel void kernel_pad_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + if (i1 < ne01 && i2 < ne02 && i3 < ne03) { + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i0 < ne00) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } + } + + return; + } + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = 0.0f; + } +} + // bitonic sort implementation following the CUDA kernels as reference typedef void (argsort_t)( device const float * x, @@ -1708,6 +1906,14 @@ kernel void kernel_argsort_f32_i32( template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; +kernel void kernel_leaky_relu_f32( + device const float * src0, + device float * dst, + constant float & slope, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; +} + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, @@ -2066,9 +2272,9 @@ kernel void kernel_cpy_f32_q4_1( } kernel void kernel_concat( - device const char * src0, - device const char * src1, - device char * dst, + device const char * src0, + device const char * src1, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -2105,7 +2311,7 @@ kernel void kernel_concat( const int64_t i12 = i02 % ne12; const int64_t i11 = i01 % ne11; - device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00; + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; @@ -3315,10 +3521,10 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { - const half d = xb->d; - const half min = xb->dmin; + const float d = xb->d; + const float min = xb->dmin; device const uint8_t * q = (device const uint8_t *)xb->qs; - half dl, ml; + float dl, ml; uint8_t sc = xb->scales[il]; #if QK_K == 256 @@ -3388,10 +3594,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg q = q + (il/4) * 32 + 16 * (il&1); il = il & 3; const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const half d = il < 2 ? xb->d : xb->d / 16.h; - const half min = xb->dmin; - const half dl = d * sc[0]; - const half ml = min * sc[1]; + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; #else q = q + 16 * (il&1); device const uint8_t * s = xb->scales; @@ -3418,13 +3624,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg uint8_t ul = 1 << (il/2); il = il & 3; const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const half d = il < 2 ? xb->d : xb->d / 16.h; - const half min = xb->dmin; - const half dl = d * sc[0]; - const half ml = min * sc[1]; + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; - const ushort mask = il<2 ? 0x0F : 0xF0; - const half qh_val = il<2 ? 16.h : 256.h; + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; } diff --git a/ggml.c b/ggml.c index 66658ff4b..29e18a24c 100644 --- a/ggml.c +++ b/ggml.c @@ -1395,7 +1395,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } -inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; } +inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } static const float GELU_COEF_A = 0.044715f; static const float GELU_QUICK_COEF = -1.702f; @@ -1623,7 +1623,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "POOL_1D", "POOL_2D", "UPSCALE", + "PAD", "ARGSORT", + "LEAKY_RELU", "FLASH_ATTN", "FLASH_FF", @@ -1650,7 +1652,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70"); +static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1707,7 +1709,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "pool_1d(x)", "pool_2d(x)", "upscale(x)", + "pad(x)", "argsort(x)", + "leaky_relu(x)", "flash_attn(x)", "flash_ff(x)", @@ -1734,7 +1738,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70"); +static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1750,10 +1754,9 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "GELU", "GELU_QUICK", "SILU", - "LEAKY", }; -static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11"); +static_assert(GGML_UNARY_OP_COUNT == 10, "GGML_UNARY_OP_COUNT != 10"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -3830,12 +3833,25 @@ struct ggml_tensor * ggml_relu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU); } -// ggml_leaky +// ggml_leaky_relu -struct ggml_tensor * ggml_leaky( +struct ggml_tensor * ggml_leaky_relu( struct ggml_context * ctx, - struct ggml_tensor * a) { - return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY); + struct ggml_tensor * a, float negative_slope, bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); + + result->op = GGML_OP_LEAKY_RELU; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; } // ggml_gelu @@ -4022,8 +4038,9 @@ static struct ggml_tensor * ggml_group_norm_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_GROUP_NORM; result->op_params[0] = n_groups; + + result->op = GGML_OP_GROUP_NORM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = NULL; // TODO: maybe store epsilon here? @@ -5523,6 +5540,30 @@ static struct ggml_tensor * ggml_upscale_impl( return result; } +struct ggml_tensor * ggml_pad( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, int p1, int p2, int p3) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, + a->ne[0] + p0, + a->ne[1] + p1, + a->ne[2] + p2, + a->ne[3] + p3); + + result->op = GGML_OP_PAD; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + struct ggml_tensor * ggml_upscale( struct ggml_context * ctx, struct ggml_tensor * a, @@ -7718,8 +7759,10 @@ static void ggml_compute_forward_mul_f32( const int ith = params->ith; const int nth = params->nth; +// TODO: OpenCL kernel support broadcast #ifdef GGML_USE_CLBLAST if (src1->backend == GGML_BACKEND_GPU) { + GGML_ASSERT(ggml_are_same_shape(src0, src1)); if (ith == 0) { ggml_cl_mul(src0, src1, dst); } @@ -8985,10 +9028,9 @@ static void ggml_compute_forward_silu( } break; } } +// ggml_compute_forward_leaky_relu -// ggml_compute_forward_leaky - -static void ggml_compute_forward_leaky_f32( +static void ggml_compute_forward_leaky_relu_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { @@ -9002,24 +9044,27 @@ static void ggml_compute_forward_leaky_f32( const int n = ggml_nrows(src0); const int nc = src0->ne[0]; + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + assert(dst->nb[0] == sizeof(float)); assert(src0->nb[0] == sizeof(float)); for (int i = 0; i < n; i++) { - ggml_vec_leaky_f32(nc, + ggml_vec_leaky_relu_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); + (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope); } } -static void ggml_compute_forward_leaky( +static void ggml_compute_forward_leaky_relu( const struct ggml_compute_params * params, const struct ggml_tensor * src0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_leaky_f32(params, src0, dst); + ggml_compute_forward_leaky_relu_f32(params, src0, dst); } break; default: { @@ -12158,6 +12203,7 @@ static void ggml_compute_forward_upscale_f32( GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; + const int nth = params->nth; GGML_TENSOR_UNARY_OP_LOCALS @@ -12165,16 +12211,17 @@ static void ggml_compute_forward_upscale_f32( // TODO: optimize - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = ith; i02 < ne02; i02++) { - for (int m = 0; m < dst->ne[1]; m++) { - int i01 = m / scale_factor; - for (int n = 0; n < dst->ne[0]; n++) { - int i00 = n / scale_factor; + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const int64_t i01 = i1 / scale_factor; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const int64_t i00 = i0 / scale_factor; - const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03); - - float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]); + const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); *y = *x; } @@ -12199,6 +12246,64 @@ static void ggml_compute_forward_upscale( } } +// ggml_compute_forward_pad + +static void ggml_compute_forward_pad_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT( dst->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float * dst_ptr = (float *) dst->data; + + // TODO: optimize + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = ith; i1 < ne1; i1 += nth) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + for (int64_t i3 = 0; i3 < ne3; ++i3) { + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + + const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + dst_ptr[dst_idx] = *src_ptr; + } else { + dst_ptr[dst_idx] = 0; + } + } + } + } + } +} + +static void ggml_compute_forward_pad( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_pad_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_argsort static void ggml_compute_forward_argsort_f32( @@ -13406,10 +13511,6 @@ static void ggml_compute_forward_unary( { ggml_compute_forward_silu(params, src0, dst); } break; - case GGML_UNARY_OP_LEAKY: - { - ggml_compute_forward_leaky(params, src0, dst); - } break; default: { GGML_ASSERT(false); @@ -14191,10 +14292,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_upscale(params, tensor->src[0], tensor); } break; + case GGML_OP_PAD: + { + ggml_compute_forward_pad(params, tensor->src[0], tensor); + } break; case GGML_OP_ARGSORT: { ggml_compute_forward_argsort(params, tensor->src[0], tensor); } break; + case GGML_OP_LEAKY_RELU: + { + ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor); + } break; case GGML_OP_FLASH_ATTN: { const int32_t t = ggml_get_op_params_i32(tensor, 0); @@ -15187,10 +15296,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_PAD: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_ARGSORT: { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_LEAKY_RELU: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_FLASH_ATTN: { struct ggml_tensor * flash_grad = NULL; @@ -15796,6 +15913,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_ARGMAX: case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: + case GGML_OP_LEAKY_RELU: { n_tasks = 1; } break; @@ -15808,7 +15926,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_LEAKY: { n_tasks = 1; } break; @@ -15927,6 +16044,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = n_threads; } break; + case GGML_OP_PAD: + { + n_tasks = n_threads; + } break; case GGML_OP_ARGSORT: { n_tasks = n_threads; diff --git a/ggml.h b/ggml.h index 32f256481..1447646b1 100644 --- a/ggml.h +++ b/ggml.h @@ -423,7 +423,9 @@ extern "C" { GGML_OP_POOL_1D, GGML_OP_POOL_2D, GGML_OP_UPSCALE, // nearest interpolate + GGML_OP_PAD, GGML_OP_ARGSORT, + GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, GGML_OP_FLASH_FF, @@ -463,7 +465,6 @@ extern "C" { GGML_UNARY_OP_GELU, GGML_UNARY_OP_GELU_QUICK, GGML_UNARY_OP_SILU, - GGML_UNARY_OP_LEAKY, GGML_UNARY_OP_COUNT, }; @@ -793,6 +794,9 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // dst = a + // view(dst, nb1, nb2, nb3, offset) += b + // return dst GGML_API struct ggml_tensor * ggml_acc( struct ggml_context * ctx, struct ggml_tensor * a, @@ -957,15 +961,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - GGML_API struct ggml_tensor * ggml_leaky( + GGML_API struct ggml_tensor * ggml_leaky_relu( struct ggml_context * ctx, - struct ggml_tensor * a); + struct ggml_tensor * a, float negative_slope, bool inplace); GGML_API struct ggml_tensor * ggml_relu_inplace( struct ggml_context * ctx, struct ggml_tensor * a); - // TODO: double-check this computation is correct GGML_API struct ggml_tensor * ggml_gelu( struct ggml_context * ctx, struct ggml_tensor * a); @@ -1551,6 +1554,15 @@ extern "C" { struct ggml_tensor * a, int scale_factor); + // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0] + GGML_API struct ggml_tensor * ggml_pad( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1, + int p2, + int p3); + // sort rows enum ggml_sort_order { GGML_SORT_ASC, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 44830b4d4..afca85143 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -234,6 +234,11 @@ static bool ggml_is_view_op(enum ggml_op op) { return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE; } +enum test_mode { + MODE_TEST, + MODE_PERF, +}; + struct test_case { virtual ~test_case() {} @@ -268,7 +273,58 @@ struct test_case { return size; } + ggml_cgraph * gf = nullptr; + + static const int sentinel_size = 1024; + + test_mode mode; + + std::vector sentinels; + + void add_sentinel(ggml_context * ctx) { + if (mode == MODE_PERF) { + return; + } + ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size); + ggml_format_name(sentinel, "sent_%zu", sentinels.size()); + sentinels.push_back(sentinel); + } + + // hijack ggml_new_tensor to add sentinels after each tensor to check for overflows in the backend + + ggml_tensor * ggml_new_tensor(ggml_context * ctx, ggml_type type, int n_dims, const int64_t * ne) { + ggml_tensor * t = ::ggml_new_tensor(ctx, type, n_dims, ne); + add_sentinel(ctx); + return t; + } + + ggml_tensor * ggml_new_tensor_1d(ggml_context * ctx, ggml_type type, int64_t ne0) { + ggml_tensor * t = ::ggml_new_tensor_1d(ctx, type, ne0); + add_sentinel(ctx); + return t; + } + + ggml_tensor * ggml_new_tensor_2d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1) { + ggml_tensor * t = ::ggml_new_tensor_2d(ctx, type, ne0, ne1); + add_sentinel(ctx); + return t; + } + + ggml_tensor * ggml_new_tensor_3d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2) { + ggml_tensor * t = ::ggml_new_tensor_3d(ctx, type, ne0, ne1, ne2); + add_sentinel(ctx); + return t; + } + + ggml_tensor * ggml_new_tensor_4d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) { + ggml_tensor * t = ::ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3); + add_sentinel(ctx); + return t; + } + bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) { + mode = MODE_TEST; + ggml_init_params params = { /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(), /* .mem_base = */ NULL, @@ -276,6 +332,11 @@ struct test_case { }; ggml_context * ctx = ggml_init(params); + gf = ggml_new_graph(ctx); + + // pre-graph sentinel + add_sentinel(ctx); + ggml_tensor * out = build_graph(ctx); if (op_name != nullptr && op_desc(out) != op_name) { @@ -296,13 +357,20 @@ struct test_case { } } + // post-graph sentinel + add_sentinel(ctx); + // allocate ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1); // build graph - ggml_cgraph * gf = ggml_new_graph(ctx); ggml_build_forward_expand(gf, out); + // add sentinels as graph nodes so that they are checked in the callback + for (ggml_tensor * sentinel : sentinels) { + gf->nodes[gf->n_nodes++] = sentinel; + } + // randomize tensors initialize_tensors(ctx); @@ -318,9 +386,24 @@ struct test_case { }; auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { + callback_userdata * ud = (callback_userdata *) user_data; + + if (t1->op == GGML_OP_NONE) { + // sentinels must be unchanged + std::vector t1_data(ggml_nbytes(t1)); + std::vector t2_data(ggml_nbytes(t2)); + ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1)); + ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2)); + + if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) { + printf("sentinel mismatch: %s ", t1->name); + ud->ok = false; + return true; + } + } + std::vector f1 = tensor_to_float(t1); std::vector f2 = tensor_to_float(t2); - callback_userdata * ud = (callback_userdata *) user_data; for (size_t i = 0; i < f1.size(); i++) { // check for nans @@ -349,9 +432,10 @@ struct test_case { if (err > ud->max_err) { printf("[%s] NMSE = %f ", ggml_op_desc(t1), err); //for (int i = 0; i < f1.size(); i++) { - // printf("(%f, %f) ", f1[i], f2[i]); + // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); //} //printf("\n"); + //exit(1); ud->ok = false; } return true; @@ -375,6 +459,8 @@ struct test_case { } bool eval_perf(ggml_backend_t backend, const char * op_name) { + mode = MODE_PERF; + static const size_t graph_nodes = 8192; ggml_init_params params = { @@ -1135,6 +1221,118 @@ struct test_sum_rows : public test_case { } }; +// GGML_OP_UPSCALE +struct test_upscale : public test_case { + const ggml_type type; + const std::array ne; + const int32_t scale_factor; + + std::string vars() override { + return VARS_TO_STR3(type, ne, scale_factor); + } + + test_upscale(ggml_type type = GGML_TYPE_F32, + std::array ne = {512, 512, 3, 1}, + int32_t scale_factor = 2) + : type(type), ne(ne), scale_factor(scale_factor) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_upscale(ctx, a, scale_factor); + return out; + } +}; + +// GGML_OP_GROUP_NORM +struct test_group_norm : public test_case { + const ggml_type type; + const std::array ne; + const int32_t num_groups; + + std::string vars() override { + return VARS_TO_STR3(type, ne, num_groups); + } + + test_group_norm(ggml_type type = GGML_TYPE_F32, + std::array ne = {64, 64, 320, 1}, + int32_t num_groups = 32) + : type(type), ne(ne), num_groups(num_groups) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_group_norm(ctx, a, num_groups); + return out; + } +}; + +// GGML_OP_ACC +struct test_acc : public test_case { + const ggml_type type; + const std::array ne_a; + const std::array ne_b; + + std::string vars() override { + return VARS_TO_STR3(type, ne_a, ne_b); + } + + test_acc(ggml_type type = GGML_TYPE_F32, + std::array ne_a = {1024, 577, 1, 1}, + std::array ne_b = {1024, 576, 1, 1}) + : type(type), ne_a(ne_a), ne_b(ne_b) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); + ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]); + return out; + } +}; + +// GGML_OP_PAD +struct test_pad : public test_case { + const ggml_type type; + const std::array ne_a; + const int pad_0; + const int pad_1; + + std::string vars() override { + return VARS_TO_STR4(type, ne_a, pad_0, pad_1); + } + + test_pad(ggml_type type = GGML_TYPE_F32, + std::array ne_a = {512, 512, 1, 1}, + int pad_0 = 1, int pad_1 = 1) + : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0); + return out; + } +}; + +// GGML_OP_LEAKY_RELU +struct test_leaky_relu : public test_case { + const ggml_type type; + const std::array ne_a; + const float negative_slope; + + std::string vars() override { + return VARS_TO_STR3(type, ne_a, negative_slope); + } + + test_leaky_relu(ggml_type type = GGML_TYPE_F32, + std::array ne_a = {10, 10, 10, 10}, + float negative_slope = 0.1f) + : type(type), ne_a(ne_a), negative_slope(negative_slope) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_tensor * out = ggml_leaky_relu(ctx, a, negative_slope, true); + return out; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1219,11 +1417,6 @@ struct test_moe : public test_case { } }; -enum test_mode { - MODE_TEST, - MODE_PERF, -}; - static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { std::vector> test_cases; @@ -1372,12 +1565,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); } - test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {10, 10, 10, 10})); - test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, {2, 1, 1, 1})); + test_cases.emplace_back(new test_sum_rows()); + test_cases.emplace_back(new test_upscale()); + test_cases.emplace_back(new test_group_norm()); + test_cases.emplace_back(new test_acc()); + test_cases.emplace_back(new test_pad()); + test_cases.emplace_back(new test_leaky_relu()); #if !defined(__SANITIZE_THREAD__) // FIXME: these tests use too much memory with thread sanitizer - test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 14336)); + test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024)); //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336)); #endif