diff --git a/ggml.c b/ggml.c index 037f0bc99..14e08f9d6 100644 --- a/ggml.c +++ b/ggml.c @@ -112,6 +112,7 @@ typedef void* thread_ret_t; /*#define GGML_PERF*/ #define GGML_DEBUG 0 #define GGML_GELU_FP16 +#define GGML_GELU_QUICK_FP16 #define GGML_SILU_FP16 #define GGML_SOFT_MAX_UNROLL 4 @@ -340,6 +341,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { // precomputed gelu table for f16 (128 KB) static ggml_fp16_t table_gelu_f16[1 << 16]; +// precomputed quick gelu table for f16 (128 KB) +static ggml_fp16_t table_gelu_quick_f16[1 << 16]; + // precomputed silu table for f16 (128 KB) static ggml_fp16_t table_silu_f16[1 << 16]; @@ -1677,14 +1681,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) { #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) #define GGML_F32x4_REDUCE(res, x) \ { \ - for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ - x[2*i] = vaddq_f32(x[2*i], x[2*i+1]); \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ - x[4*i] = vaddq_f32(x[4*i], x[4*i+2]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ - x[8*i] = vaddq_f32(x[8*i], x[8*i+4]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f32(x[i], x[offset+i]); \ } \ res = GGML_F32x4_REDUCE_ONE(x[0]); \ } @@ -1715,14 +1722,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) { #define GGML_F16x8_MUL vmulq_f16 #define GGML_F16x8_REDUCE(res, x) \ { \ - for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ - x[2*i] = vaddq_f16(x[2*i], x[2*i+1]); \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ - x[4*i] = vaddq_f16(x[4*i], x[4*i+2]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ - x[8*i] = vaddq_f16(x[8*i], x[8*i+4]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vaddq_f16(x[i], x[offset+i]); \ } \ const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ @@ -1789,14 +1799,17 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) { #define GGML_F32x8_MUL _mm256_mul_ps #define GGML_F32x8_REDUCE(res, x) \ { \ - for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ - x[2*i] = _mm256_add_ps(x[2*i], x[2*i+1]); \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ - x[4*i] = _mm256_add_ps(x[4*i], x[4*i+2]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ - x[8*i] = _mm256_add_ps(x[8*i], x[8*i+4]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ } \ const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ _mm256_extractf128_ps(x[0], 1)); \ @@ -1886,14 +1899,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { #define GGML_F32x4_MUL vec_mul #define GGML_F32x4_REDUCE(res, x) \ { \ - for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ - x[2*i] = vec_add(x[2*i], x[2*i+1]); \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ - x[4*i] = vec_add(x[4*i], x[4*i+2]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ - x[8*i] = vec_add(x[8*i], x[8*i+4]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ } \ res = vec_extract(x[0], 0) + \ vec_extract(x[0], 1) + \ @@ -1949,14 +1965,17 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { #define GGML_F32x4_MUL wasm_f32x4_mul #define GGML_F32x4_REDUCE(res, x) \ { \ - for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ - x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ - x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ - x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ } \ res = wasm_f32x4_extract_lane(x[0], 0) + \ wasm_f32x4_extract_lane(x[0], 1) + \ @@ -2011,14 +2030,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { #define GGML_F16x4_MUL wasm_f32x4_mul #define GGML_F16x4_REDUCE(res, x) \ { \ - for (int i = 0; i < GGML_F16_ARR/2; ++i) { \ - x[2*i] = wasm_f32x4_add(x[2*i], x[2*i+1]); \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F16_ARR/4; ++i) { \ - x[4*i] = wasm_f32x4_add(x[4*i], x[4*i+2]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F16_ARR/8; ++i) { \ - x[8*i] = wasm_f32x4_add(x[8*i], x[8*i+4]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ } \ res = wasm_f32x4_extract_lane(x[0], 0) + \ wasm_f32x4_extract_lane(x[0], 1) + \ @@ -2060,14 +2082,17 @@ inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { #define GGML_F32x4_MUL _mm_mul_ps #define GGML_F32x4_REDUCE(res, x) \ { \ - for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ - x[2*i] = _mm_add_ps(x[2*i], x[2*i+1]); \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F32_ARR/4; ++i) { \ - x[4*i] = _mm_add_ps(x[4*i], x[4*i+2]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ } \ - for (int i = 0; i < GGML_F32_ARR/8; ++i) { \ - x[8*i] = _mm_add_ps(x[8*i], x[8*i+4]); \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ } \ const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ res = _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ @@ -3356,6 +3381,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } static const float GELU_COEF_A = 0.044715f; +static const float GELU_QUICK_COEF = -1.702f; static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; inline static float ggml_gelu_f32(float x) { @@ -3386,6 +3412,34 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { } #endif +inline static float ggml_gelu_quick_f32(float x) { + return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); +} + +//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { +// const uint16_t * i16 = (const uint16_t *) x; +// for (int i = 0; i < n; ++i) { +// y[i] = table_gelu_quick_f16[i16[i]]; +// } +//} + +#ifdef GGML_GELU_QUICK_FP16 +inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]); + } +} +#else +inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_quick_f32(x[i]); + } +} +#endif + // Sigmoid Linear Unit (SiLU) function inline static float ggml_silu_f32(float x) { return x/(1.0f + expf(-x)); @@ -3616,6 +3670,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "STEP", "RELU", "GELU", + "GELU_QUICK", "SILU", "SILU_BACK", "NORM", @@ -3644,12 +3699,15 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ROPE_BACK", "ALIBI", "CLAMP", - "CONV_1D_1S", - "CONV_1D_2S", + "CONV_1D_S1_PH", + "CONV_1D_S2_PH", + "CONV_2D_SK_P0", "FLASH_ATTN", "FLASH_FF", "FLASH_ATTN_BACK", + "WIN_PART", + "WIN_UNPART", "MAP_UNARY", "MAP_BINARY", @@ -3658,7 +3716,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57"); +static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3684,6 +3742,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "step(x)", "relu(x)", "gelu(x)", + "gelu_quick(x)", "silu(x)", "silu_back(x)", "norm(x)", @@ -3712,12 +3771,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rope_back(x)", "alibi(x)", "clamp(x)", - "conv_1d_1s(x)", - "conv_1d_2s(x)", + "conv_1d_s1_ph(x)", + "conv_1d_s2_ph(x)", + "conv_2d_sk_p0(x)", "flash_attn(x)", "flash_ff(x)", "flash_attn_back(x)", + "win_part(x)", + "win_unpart(x)", "f(x)", "f(x,y)", @@ -3726,7 +3788,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 57, "GGML_OP_COUNT != 57"); +static_assert(GGML_OP_COUNT == 61, "GGML_OP_COUNT != 61"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -4017,7 +4079,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { // initialize time system (required on Windows) ggml_time_init(); - // initialize GELU, SILU and EXP F32 tables + // initialize GELU, Quick GELU, SILU and EXP F32 tables { const uint64_t t_start = ggml_time_us(); UNUSED(t_start); @@ -4027,13 +4089,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { memcpy(&ii, &ui, sizeof(ii)); const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii); table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); + table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f)); table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f)); } const uint64_t t_end = ggml_time_us(); UNUSED(t_end); - GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); } // initialize g_state @@ -4665,9 +4728,10 @@ const char * ggml_get_name(const struct ggml_tensor * tensor) { return tensor->name; } -void ggml_set_name(struct ggml_tensor * tensor, const char * name) { +struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) { strncpy(tensor->name, name, sizeof(tensor->name)); tensor->name[sizeof(tensor->name) - 1] = '\0'; + return tensor; } struct ggml_tensor * ggml_view_tensor( @@ -5446,6 +5510,40 @@ struct ggml_tensor * ggml_gelu_inplace( return ggml_gelu_impl(ctx, a, true); } +// ggml_gelu_quick + +struct ggml_tensor * ggml_gelu_quick_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_GELU_QUICK; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_gelu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_gelu_quick_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_gelu_quick_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_gelu_quick_impl(ctx, a, true); +} + // ggml_silu struct ggml_tensor * ggml_silu_impl( @@ -6645,7 +6743,7 @@ struct ggml_tensor * ggml_clamp( ggml_scratch_save(ctx); - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2); ((float *) b->data)[0] = min; ((float *) b->data)[1] = max; @@ -6660,9 +6758,9 @@ struct ggml_tensor * ggml_clamp( return result; } -// ggml_conv_1d_1s +// ggml_conv_1d_s1_ph -struct ggml_tensor * ggml_conv_1d_1s( +struct ggml_tensor * ggml_conv_1d_s1_ph( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { @@ -6679,7 +6777,7 @@ struct ggml_tensor * ggml_conv_1d_1s( const int64_t ne[4] = { b->ne[0], a->ne[2], 1, 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - result->op = GGML_OP_CONV_1D_1S; + result->op = GGML_OP_CONV_1D_S1_PH; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; @@ -6687,9 +6785,9 @@ struct ggml_tensor * ggml_conv_1d_1s( return result; } -// ggml_conv_1d_2s +// ggml_conv_1d_s2_ph -struct ggml_tensor * ggml_conv_1d_2s( +struct ggml_tensor * ggml_conv_1d_s2_ph( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { @@ -6706,7 +6804,35 @@ struct ggml_tensor * ggml_conv_1d_2s( const int64_t ne[4] = { b->ne[0]/2, a->ne[2], 1, 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - result->op = GGML_OP_CONV_1D_2S; + result->op = GGML_OP_CONV_1D_S2_PH; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = b; + + return result; +} + +// ggml_conv_2d_sk_p0 + +struct ggml_tensor * ggml_conv_2d_sk_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(b->ne[3] == 1); + GGML_ASSERT(a->ne[2] == b->ne[2]); + GGML_ASSERT(b->ne[0] % a->ne[0] == 0); + GGML_ASSERT(b->ne[1] % a->ne[1] == 0); + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { b->ne[0]/a->ne[0], b->ne[1]/a->ne[1], a->ne[3], 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_CONV_2D_SK_P0; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; @@ -6840,6 +6966,89 @@ struct ggml_tensor * ggml_flash_attn_back( return result; } +// ggml_win_part + +struct ggml_tensor * ggml_win_part( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w) { + GGML_ASSERT(a->ne[3] == 1); + GGML_ASSERT(a->type == GGML_TYPE_F32); + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + // padding + const int px = (w - a->ne[1]%w)%w; + const int py = (w - a->ne[2]%w)%w; + + const int npx = (px + a->ne[1])/w; + const int npy = (py + a->ne[2])/w; + const int np = npx*npy; + + const int64_t ne[4] = { a->ne[0], w, w, np, }; + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_scratch_save(ctx); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + + ((int32_t *) b->data)[0] = npx; + ((int32_t *) b->data)[1] = npy; + ((int32_t *) b->data)[2] = w; + + ggml_scratch_load(ctx); + + result->op = GGML_OP_WIN_PART; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + result->opt[0] = b; + + return result; +} + +// ggml_win_unpart + +struct ggml_tensor * ggml_win_unpart( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int w) { + GGML_ASSERT(a->type == GGML_TYPE_F32); + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + + ggml_scratch_save(ctx); + + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + + ((int32_t *) b->data)[0] = w; + + ggml_scratch_load(ctx); + + result->op = GGML_OP_WIN_UNPART; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + result->opt[0] = b; + + return result; +} // ggml_map_unary @@ -9479,8 +9688,65 @@ static void ggml_compute_forward_gelu( GGML_ASSERT(false); } break; } +} - //printf("XXXXXXXX gelu\n"); +// ggml_compute_forward_gelu_quick + +static void ggml_compute_forward_gelu_quick_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_quick_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu_quick( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_quick_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } } // ggml_compute_forward_silu @@ -10878,7 +11144,7 @@ static void ggml_compute_forward_set_f32( const int im2 = (ne12 == 0 ? 0 : ne12-1); const int im3 = (ne13 == 0 ? 0 : ne13-1); - GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 < ggml_nbytes(dst)); + GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst)); GGML_ASSERT(nb10 == sizeof(float)); @@ -11599,8 +11865,9 @@ static void ggml_compute_forward_alibi_f32( const struct ggml_tensor * src1, struct ggml_tensor * dst) { assert(params->ith == 0); - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -11663,8 +11930,9 @@ static void ggml_compute_forward_alibi_f16( const struct ggml_tensor * src1, struct ggml_tensor * dst) { assert(params->ith == 0); - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -11766,15 +12034,16 @@ static void ggml_compute_forward_clamp_f32( const struct ggml_tensor * src1, struct ggml_tensor * dst) { assert(params->ith == 0); - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nelements(src1) == 2); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int min = ((float *) src1->data)[0]; - const int max = ((float *) src1->data)[1]; + const float min = ((float *) src1->data)[0]; + const float max = ((float *) src1->data)[1]; const int ith = params->ith; const int nth = params->nth; @@ -12332,9 +12601,9 @@ static void ggml_compute_forward_rope_back( } } -// ggml_compute_forward_conv_1d_1s +// ggml_compute_forward_conv_1d_s1_ph -static void ggml_compute_forward_conv_1d_1s_f16_f32( +static void ggml_compute_forward_conv_1d_s1_ph_f16_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -12454,7 +12723,7 @@ static void ggml_compute_forward_conv_1d_1s_f16_f32( } } -static void ggml_compute_forward_conv_1d_1s_f32( +static void ggml_compute_forward_conv_1d_s1_ph_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -12574,7 +12843,7 @@ static void ggml_compute_forward_conv_1d_1s_f32( } } -static void ggml_compute_forward_conv_1d_1s( +static void ggml_compute_forward_conv_1d_s1_ph( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -12582,11 +12851,11 @@ static void ggml_compute_forward_conv_1d_1s( switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_conv_1d_1s_f16_f32(params, src0, src1, dst); + ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst); + ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst); } break; default: { @@ -12595,9 +12864,9 @@ static void ggml_compute_forward_conv_1d_1s( } } -// ggml_compute_forward_conv_1d_2s +// ggml_compute_forward_conv_1d_s2_ph -static void ggml_compute_forward_conv_1d_2s_f16_f32( +static void ggml_compute_forward_conv_1d_s2_ph_f16_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -12717,7 +12986,7 @@ static void ggml_compute_forward_conv_1d_2s_f16_f32( } } -static void ggml_compute_forward_conv_1d_2s_f32( +static void ggml_compute_forward_conv_1d_s2_ph_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -12837,7 +13106,7 @@ static void ggml_compute_forward_conv_1d_2s_f32( } } -static void ggml_compute_forward_conv_1d_2s( +static void ggml_compute_forward_conv_1d_s2_ph( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, @@ -12845,11 +13114,148 @@ static void ggml_compute_forward_conv_1d_2s( switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_conv_1d_2s_f16_f32(params, src0, src1, dst); + ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst); + ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_conv_2d_sk_p0 + +static void ggml_compute_forward_conv_2d_sk_p0_f16_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + //const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + //const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + //const int ne13 = src1->ne[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + //const int ne3 = dst->ne[3]; + //const int ne = ne0*ne1*ne2*ne3; + + const int nb00 = src0->nb[0]; + //const int nb01 = src0->nb[1]; + //const int nb02 = src0->nb[2]; + const int nb03 = src0->nb[3]; + + const int nb10 = src1->nb[0]; + //const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + //const int nb13 = src1->nb[3]; + + //const int nb0 = dst->nb[0]; + //const int nb1 = dst->nb[1]; + const int nb2 = dst->nb[2]; + //const int nb3 = dst->nb[3]; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk0 = ne00; + const int nk1 = ne01; + + // size of the convolution row - the kernel size unrolled across all channels + // round-up so it is more suitable for SIMD + const int ew0 = ggml_up32(nk0*nk1*ne02); + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + // TODO: fix this memset (wsize is overestimated) + memset(params->wdata, 0, params->wsize); + + // prepare source data (src1) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i12 = 0; i12 < ne12; i12++) { + const float * const src = (float *)((char *) src1->data + i12*nb12); + ggml_fp16_t * dst_data = wdata; + + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + for (int ik1 = 0; ik1 < nk1; ik1++) { + for (int ik0 = 0; ik0 < nk0; ik0++) { + dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] = + GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]); + } + } + } + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + // total patches in dst + const int np = ne2; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int i2 = ip0; i2 < ip1; i2++) { + float * dst_data = (float *)((char *) dst->data + i2*nb2); + + for (int i1 = 0; i1 < ne1; ++i1) { + for (int i0 = 0; i0 < ne0; ++i0) { + ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0, + (ggml_fp16_t *) ((char *) src0->data + i2*nb03), + (ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0); + } + } + } +} + +static void ggml_compute_forward_conv_2d_sk_p0( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst); + } break; + case GGML_TYPE_F32: + { + //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst); + GGML_ASSERT(false); } break; default: { @@ -13952,6 +14358,145 @@ static void ggml_compute_forward_flash_attn_back( } } +// ggml_compute_forward_win_part + +static void ggml_compute_forward_win_part_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int64_t ne00 = src0->ne[0]; UNUSED(ne00); + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; UNUSED(ne03); + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; UNUSED(ne3); + + const int32_t nep0 = ((const int32_t *)(opt0->data))[0]; + const int32_t nep1 = ((const int32_t *)(opt0->data))[1]; + const int32_t w = ((const int32_t *)(opt0->data))[2]; + + assert(ne00 == ne0); + assert(ne3 == nep0*nep1); + + // TODO: optimize / multi-thread + for (int py = 0; py < nep1; ++py) { + for (int px = 0; px < nep0; ++px) { + const int64_t i3 = py*nep0 + px; + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i02 = py*w + i2; + const int64_t i01 = px*w + i1; + const int64_t i00 = i0; + + const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0; + const int64_t j = i02*ne01*ne00 + i01*ne00 + i00; + + if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { + ((float *) dst->data)[i] = 0.0f; + } else { + ((float *) dst->data)[i] = ((float *) src0->data)[j]; + } + } + } + } + } + } +} + +static void ggml_compute_forward_win_part( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_win_part_f32(params, src0, opt0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_win_unpart + +static void ggml_compute_forward_win_unpart_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + //const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + + const int32_t w = ((const int32_t *)(opt0->data))[0]; + + // padding + const int px = (w - ne1%w)%w; + //const int py = (w - ne2%w)%w; + + const int npx = (px + ne1)/w; + //const int npy = (py + ne2)/w; + + assert(ne0 == ne00); + + // TODO: optimize / multi-thread + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int ip2 = i2/w; + const int ip1 = i1/w; + + const int64_t i02 = i2%w; + const int64_t i01 = i1%w; + const int64_t i00 = i0; + + const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00; + const int64_t j = i2*ne1*ne0 + i1*ne0 + i0; + + ((float *) dst->data)[j] = ((float *) src0->data)[i]; + } + } + } +} + +static void ggml_compute_forward_win_unpart( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_map_unary static void ggml_compute_forward_map_unary_f32( @@ -14424,6 +14969,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_gelu(params, tensor->src0, tensor); } break; + case GGML_OP_GELU_QUICK: + { + ggml_compute_forward_gelu_quick(params, tensor->src0, tensor); + } break; case GGML_OP_SILU: { ggml_compute_forward_silu(params, tensor->src0, tensor); @@ -14528,19 +15077,23 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor); } break; - case GGML_OP_CONV_1D_1S: + case GGML_OP_CONV_1D_S1_PH: { - ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor); + ggml_compute_forward_conv_1d_s1_ph(params, tensor->src0, tensor->src1, tensor); } break; - case GGML_OP_CONV_1D_2S: + case GGML_OP_CONV_1D_S2_PH: { - ggml_compute_forward_conv_1d_2s(params, tensor->src0, tensor->src1, tensor); + ggml_compute_forward_conv_1d_s2_ph(params, tensor->src0, tensor->src1, tensor); + } break; + case GGML_OP_CONV_2D_SK_P0: + { + ggml_compute_forward_conv_2d_sk_p0(params, tensor->src0, tensor->src1, tensor); } break; case GGML_OP_FLASH_ATTN: { - int32_t t = ggml_get_i32_1d(tensor->opt[1], 0); + const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0); GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; + const bool masked = t != 0; ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor); } break; case GGML_OP_FLASH_FF: @@ -14554,6 +15107,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm bool masked = t != 0; ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor); } break; + case GGML_OP_WIN_PART: + { + ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor); + } break; + case GGML_OP_WIN_UNPART: + { + ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor); + } break; case GGML_OP_MAP_UNARY: { const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data); @@ -14825,6 +15386,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_GELU_QUICK: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_ALIBI: { GGML_ASSERT(false); // TODO: not implemented @@ -15187,11 +15752,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // noop } } break; - case GGML_OP_CONV_1D_1S: + case GGML_OP_CONV_1D_S1_PH: { GGML_ASSERT(false); // TODO: not implemented } break; - case GGML_OP_CONV_1D_2S: + case GGML_OP_CONV_1D_S2_PH: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_CONV_2D_SK_P0: { GGML_ASSERT(false); // TODO: not implemented } break; @@ -15360,6 +15929,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // not supported } break; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: { @@ -15768,6 +16339,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) } break; case GGML_OP_MUL: case GGML_OP_GELU: + case GGML_OP_GELU_QUICK: case GGML_OP_SILU: case GGML_OP_SILU_BACK: case GGML_OP_NORM: @@ -15874,8 +16446,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { node->n_tasks = 1; //TODO } break; - case GGML_OP_CONV_1D_1S: - case GGML_OP_CONV_1D_2S: + case GGML_OP_CONV_1D_S1_PH: + case GGML_OP_CONV_1D_S2_PH: { node->n_tasks = n_threads; @@ -15902,6 +16474,41 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) GGML_ASSERT(false); } + work_size = MAX(work_size, cur); + } break; + case GGML_OP_CONV_2D_SK_P0: + { + node->n_tasks = n_threads; + + GGML_ASSERT(node->src1->ne[3] == 1); + + const int64_t ne00 = node->src0->ne[0]; // W + const int64_t ne01 = node->src0->ne[1]; // H + const int64_t ne02 = node->src0->ne[2]; // C + const int64_t ne03 = node->src0->ne[3]; // N + + const int64_t ne10 = node->src1->ne[0]; // W + const int64_t ne11 = node->src1->ne[1]; // H + const int64_t ne12 = node->src1->ne[2]; // C + + const int64_t nk = ne00*ne01; + + UNUSED(ne02); + UNUSED(ne03); + UNUSED(nk); + + size_t cur = 0; + + if (node->src0->type == GGML_TYPE_F16 && + node->src1->type == GGML_TYPE_F32) { + cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12); + } else if (node->src0->type == GGML_TYPE_F32 && + node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)* (ne10*ne11*ne12); + } else { + GGML_ASSERT(false); + } + work_size = MAX(work_size, cur); } break; case GGML_OP_FLASH_ATTN: @@ -15963,6 +16570,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) work_size = MAX(work_size, cur); } break; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: { @@ -16495,16 +17104,20 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** if (!*ctx_data) { fprintf(stderr, "%s: failed to create ggml context\n", __func__); + fclose(fin); return result; } } data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize); - const size_t ret = fread(data->data, sizeof(char), fsize, fin); - if (ret != fsize) { - fprintf(stderr, "%s: failed to read %s\n", __func__, fname); - return result; + { + const size_t ret = fread(data->data, sizeof(char), fsize, fin); + if (ret != fsize) { + fprintf(stderr, "%s: failed to read %s\n", __func__, fname); + fclose(fin); + return result; + } } fclose(fin); diff --git a/ggml.h b/ggml.h index 1380c530f..18c78551f 100644 --- a/ggml.h +++ b/ggml.h @@ -303,6 +303,7 @@ extern "C" { GGML_OP_STEP, GGML_OP_RELU, GGML_OP_GELU, + GGML_OP_GELU_QUICK, GGML_OP_SILU, GGML_OP_SILU_BACK, GGML_OP_NORM, // normalize @@ -331,12 +332,15 @@ extern "C" { GGML_OP_ROPE_BACK, GGML_OP_ALIBI, GGML_OP_CLAMP, - GGML_OP_CONV_1D_1S, - GGML_OP_CONV_1D_2S, + GGML_OP_CONV_1D_S1_PH, + GGML_OP_CONV_1D_S2_PH, + GGML_OP_CONV_2D_SK_P0, GGML_OP_FLASH_ATTN, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, + GGML_OP_WIN_PART, + GGML_OP_WIN_UNPART, GGML_OP_MAP_UNARY, GGML_OP_MAP_BINARY, @@ -557,8 +561,8 @@ extern "C" { GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); - GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor); - GGML_API void ggml_set_name(struct ggml_tensor * tensor, const char * name); + GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name); // // operations on tensors with backpropagation @@ -611,24 +615,47 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_mul( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_div( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_sqr( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sqrt( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_log( struct ggml_context * ctx, struct ggml_tensor * a); @@ -668,31 +695,67 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sgn( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_neg( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_step( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_relu( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + // TODO: double-check this computation is correct GGML_API struct ggml_tensor * ggml_gelu( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_quick_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_silu( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_silu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + // a - x // b - dy GGML_API struct ggml_tensor * ggml_silu_back( @@ -706,10 +769,18 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + // a - x // b - dy GGML_API struct ggml_tensor * ggml_rms_norm_back( @@ -999,16 +1070,55 @@ extern "C" { float min, float max); - // padding = 1 + // TODO: implement general-purpose convolutions + // GGML_API struct ggml_tensor * ggml_conv_1d( + // struct ggml_context * ctx, + // struct ggml_tensor * a, + // struct ggml_tensor * b, + // int s0 + // int p0, + // int d0); + // + // GGML_API struct ggml_tensor * ggml_conv_2d( + // struct ggml_context * ctx, + // struct ggml_tensor * a, + // struct ggml_tensor * b, + // int s0, + // int s1, + // int p0, + // int p1, + // int d0, + // int d1); + + // padding = half // TODO: we don't support extra parameters for now // that's why we are hard-coding the stride, padding, and dilation // not great .. - GGML_API struct ggml_tensor * ggml_conv_1d_1s( + // example: + // a: 3 80 768 1 + // b: 3000 80 1 1 + // res: 3000 768 1 1 + // used in whisper + GGML_API struct ggml_tensor * ggml_conv_1d_s1_ph( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b); - GGML_API struct ggml_tensor * ggml_conv_1d_2s( + // used in whisper + GGML_API struct ggml_tensor * ggml_conv_1d_s2_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // kernel size is a->ne[0] x a->ne[1] + // stride is equal to kernel size + // padding is zero + // example: + // a: 16 16 3 768 + // b: 1024 1024 3 1 + // res: 64 64 768 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b); @@ -1036,6 +1146,26 @@ extern "C" { struct ggml_tensor * c0, struct ggml_tensor * c1); + // partition into non-overlapping windows with padding if needed + // example: + // a: 768 64 64 1 + // w: 14 + // res: 768 14 14 25 + // used in sam + GGML_API struct ggml_tensor * ggml_win_part( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w); + + // reverse of ggml_win_part + // used in sam + GGML_API struct ggml_tensor * ggml_win_unpart( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int w); + // Mapping operations typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *); typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);