diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 31d6620a2..79b117df7 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1868,10 +1868,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head); t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd); t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch); - t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch); + t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch); t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch); t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch); - t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch); + t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx, 10000.0f, 1.0f, 0.0f, false)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch); t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch); t04->grad = expand(gb, ggml_add_inplace(ctx0, ggml_add_inplace(ctx0, diff --git a/ggml-alloc.c b/ggml-alloc.c index 3ee98d03d..f06f9a3c1 100644 --- a/ggml-alloc.c +++ b/ggml-alloc.c @@ -76,7 +76,7 @@ struct ggml_allocr { }; #ifdef GGML_ALLOCATOR_DEBUG -static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) { +static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { for (int i = 0; i < 1024; i++) { if (alloc->allocated_tensors[i] == NULL) { alloc->allocated_tensors[i] = tensor; @@ -85,7 +85,7 @@ static void add_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tens } GGML_ASSERT(!"out of allocated_tensors"); } -static void remove_allocated_tensor(struct ggml_allocator * alloc, struct ggml_tensor * tensor) { +static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) { for (int i = 0; i < 1024; i++) { if (alloc->allocated_tensors[i] == tensor || (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) { diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 5b415c646..c0fb9fb65 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -259,6 +259,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 +#define CUDA_ALIBI_BLOCK_SIZE 32 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 #define CUDA_QUANTIZE_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 @@ -3940,6 +3941,29 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; } +static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows, + const int n_heads_log2_floor, const float m0, const float m1) { + const int col = blockDim.x*blockIdx.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int row = blockDim.y*blockIdx.y + threadIdx.y; + const int i = row*ncols + col; + + const int k = row/k_rows; + + float m_k; + if (k < n_heads_log2_floor) { + m_k = powf(m0, k + 1); + } else { + m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); + } + + dst[i] = col * m_k + x[i]; +} + static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { const int col = blockDim.x*blockIdx.x + threadIdx.x; const int row = blockDim.y*blockIdx.y + threadIdx.y; @@ -4766,6 +4790,15 @@ static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, con rope_glm_f32<<>>(x, dst, ncols, p, block_p, theta_scale); } +static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, + const int k_rows, const int n_heads_log2_floor, const float m0, + const float m1, cudaStream_t stream) { + const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1); + const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE); + const dim3 block_nums(num_blocks_x, nrows, 1); + alibi_f32<<>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1); +} + static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1); const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; @@ -5501,6 +5534,41 @@ inline void ggml_cuda_op_rope( (void) i1; } +inline void ggml_cuda_op_alibi( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, + float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, + cudaStream_t & cudaStream_main){ + + GGML_ASSERT(src0_ddf_i != nullptr); + GGML_ASSERT(dst_ddf_i != nullptr); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t i01_diff = i01_high - i01_low; + + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); + + GGML_ASSERT(ne01 + n_past == ne00); + GGML_ASSERT(n_head == ne02); + + const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); + + // compute + alibi_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_heads_log2_floor, m0, m1, cudaStream_main); + + (void) src1; + (void) src0_ddq_i; + (void) src1_ddf_i; + (void) i1; +} + inline void ggml_cuda_op_diag_mask_inf( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, @@ -6121,6 +6189,11 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm } +void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_alibi, true, true); +} + void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { (void) src0; (void) src1; @@ -6456,6 +6529,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ } func = ggml_cuda_rope; break; + case GGML_OP_ALIBI: + if (!any_on_device) { + return false; + } + func = ggml_cuda_alibi; + break; default: return false; } diff --git a/ggml.c b/ggml.c index c917d73c7..dffb97731 100644 --- a/ggml.c +++ b/ggml.c @@ -216,7 +216,6 @@ inline static void * ggml_aligned_malloc(size_t size) { GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); return NULL; } - return aligned_memory; } #define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size) @@ -3722,6 +3721,10 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { *s = idx; } +// +// data types +// + static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "NONE", @@ -3741,10 +3744,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARGMAX", "REPEAT", "REPEAT_BACK", + "CONCAT", "SILU_BACK", "NORM", "RMS_NORM", "RMS_NORM_BACK", + "GROUP_NORM", "MUL_MAT", "OUT_PROD", @@ -3770,20 +3775,28 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CLAMP", "CONV_1D", "CONV_2D", + "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", + "UPSCALE", "FLASH_ATTN", "FLASH_FF", "FLASH_ATTN_BACK", "WIN_PART", "WIN_UNPART", + "GET_REL_POS", + "ADD_REL_POS", "UNARY", "MAP_UNARY", "MAP_BINARY", + "MAP_CUSTOM1_F32", + "MAP_CUSTOM2_F32", + "MAP_CUSTOM3_F32", + "MAP_CUSTOM1", "MAP_CUSTOM2", "MAP_CUSTOM3", @@ -3792,7 +3805,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 62, "GGML_OP_COUNT != 62"); +static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3813,10 +3826,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "argmax(x)", "repeat(x)", "repeat_back(x)", + "concat(x, y)", "silu_back(x)", "norm(x)", "rms_norm(x)", "rms_norm_back(x)", + "group_norm(x)", "X*Y", "X*Y", @@ -3842,20 +3857,28 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "clamp(x)", "conv_1d(x)", "conv_2d(x)", + "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", + "upscale(x)", "flash_attn(x)", "flash_ff(x)", "flash_attn_back(x)", "win_part(x)", "win_unpart(x)", + "get_rel_pos(x)", + "add_rel_pos(x)", "unary(x)", "f(x)", "f(x,y)", + "custom_f32(x)", + "custom_f32(x,y)", + "custom_f32(x,y,z)", + "custom(x)", "custom(x,y)", "custom(x,y,z)", @@ -3864,7 +3887,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 62, "GGML_OP_COUNT != 62"); +static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -3894,8 +3917,10 @@ static void ggml_setup_op_has_task_pass(void) { p[GGML_OP_DIAG_MASK_ZERO ] = true; p[GGML_OP_CONV_1D ] = true; p[GGML_OP_CONV_2D ] = true; + p[GGML_OP_CONV_TRANSPOSE_2D ] = true; p[GGML_OP_FLASH_ATTN_BACK ] = true; p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; + p[GGML_OP_ADD_REL_POS ] = true; } { // FINALIZE @@ -5572,6 +5597,30 @@ struct ggml_tensor * ggml_repeat_back( return result; } +// ggml_concat + +struct ggml_tensor* ggml_concat( + struct ggml_context* ctx, + struct ggml_tensor* a, + struct ggml_tensor* b) { + GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]); + + bool is_node = false; + + if (a->grad || b->grad) { + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]); + + result->op = GGML_OP_CONCAT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_abs struct ggml_tensor * ggml_abs( @@ -5771,6 +5820,8 @@ struct ggml_tensor * ggml_norm_inplace( return ggml_norm_impl(ctx, a, true); } +// ggml_rms_norm + static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, @@ -5807,6 +5858,8 @@ struct ggml_tensor * ggml_rms_norm_inplace( return ggml_rms_norm_impl(ctx, a, eps, true); } +// ggml_rms_norm_back + struct ggml_tensor * ggml_rms_norm_back( struct ggml_context * ctx, struct ggml_tensor * a, @@ -5828,6 +5881,44 @@ struct ggml_tensor * ggml_rms_norm_back( return result; } +// ggml_group_norm + +static struct ggml_tensor * ggml_group_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + bool inplace) { + + bool is_node = false; + if (!inplace && (a->grad)) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_GROUP_NORM; + result->op_params[0] = n_groups; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct ggml_tensor * ggml_group_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups) { + return ggml_group_norm_impl(ctx, a, n_groups, false); +} + +struct ggml_tensor * ggml_group_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups) { + return ggml_group_norm_impl(ctx, a, n_groups, true); +} // ggml_mul_mat @@ -6696,6 +6787,8 @@ static struct ggml_tensor * ggml_rope_impl( int n_ctx, float freq_base, float freq_scale, + float xpos_base, + bool xpos_down, bool inplace) { GGML_ASSERT(n_past >= 0); bool is_node = false; @@ -6706,9 +6799,11 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[6] = { n_past, n_dims, mode, n_ctx }; + int32_t params[8] = { n_past, n_dims, mode, n_ctx }; memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float)); + memcpy(params + 6, &xpos_base, sizeof(float)); + memcpy(params + 7, &xpos_down, sizeof(bool)); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE; @@ -6725,7 +6820,7 @@ struct ggml_tensor * ggml_rope( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); } struct ggml_tensor * ggml_rope_inplace( @@ -6735,7 +6830,7 @@ struct ggml_tensor * ggml_rope_inplace( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); } struct ggml_tensor * ggml_rope_custom( @@ -6747,7 +6842,7 @@ struct ggml_tensor * ggml_rope_custom( int n_ctx, float freq_base, float freq_scale) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, false); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); } struct ggml_tensor * ggml_rope_custom_inplace( @@ -6759,7 +6854,17 @@ struct ggml_tensor * ggml_rope_custom_inplace( int n_ctx, float freq_base, float freq_scale) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); +} + +struct ggml_tensor * ggml_rope_xpos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + float base, + bool down) { + return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); } // ggml_rope_back @@ -6770,7 +6875,11 @@ struct ggml_tensor * ggml_rope_back( int n_past, int n_dims, int mode, - int n_ctx) { + int n_ctx, + float freq_base, + float freq_scale, + float xpos_base, + bool xpos_down) { GGML_ASSERT(n_past >= 0); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); @@ -6782,7 +6891,11 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - int32_t params[] = { n_past, n_dims, mode, n_ctx }; + int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + memcpy(params + 4, &freq_base, sizeof(float)); + memcpy(params + 5, &freq_scale, sizeof(float)); + memcpy(params + 6, &xpos_base, sizeof(float)); + memcpy(params + 7, &xpos_down, sizeof(bool)); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE_BACK; @@ -6889,6 +7002,17 @@ GGML_API struct ggml_tensor * ggml_conv_1d( return result; } +// ggml_conv_1d_ph + +struct ggml_tensor* ggml_conv_1d_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s, + int d) { + return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); +} + // ggml_conv_2d struct ggml_tensor * ggml_conv_2d( @@ -6929,17 +7053,59 @@ struct ggml_tensor * ggml_conv_2d( } -// ggml_conv_1d_ph +// ggml_conv_2d_sk_p0 -struct ggml_tensor * ggml_conv_1d_ph( +struct ggml_tensor * ggml_conv_2d_sk_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1); +} + +// ggml_conv_2d_s1_ph + +struct ggml_tensor * ggml_conv_2d_s1_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1); +} + +// ggml_conv_transpose_2d_p0 + +static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { + return (ins - 1) * s - 2 * p + ks; +} + +struct ggml_tensor * ggml_conv_transpose_2d_p0( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - int s, - int d) { - return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); -} + int stride) { + GGML_ASSERT(a->ne[3] == b->ne[2]); + bool is_node = false; + + if (a->grad || b->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/), + ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/), + a->ne[2], b->ne[3], + }; + + struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + result->op = GGML_OP_CONV_TRANSPOSE_2D; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + result->src[2] = ggml_new_i32(ctx, stride); + + return result; +} // ggml_pool_* @@ -7017,6 +7183,40 @@ struct ggml_tensor * ggml_pool_2d( return result; } +// ggml_upscale + +static struct ggml_tensor * ggml_upscale_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor) { + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, + a->ne[0] * scale_factor, + a->ne[1] * scale_factor, + a->ne[2], a->ne[3]); + + result->op = GGML_OP_UPSCALE; + result->op_params[0] = scale_factor; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = NULL; + + return result; +} + +struct ggml_tensor * ggml_upscale( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor) { + return ggml_upscale_impl(ctx, a, scale_factor); +} + // ggml_flash_attn struct ggml_tensor * ggml_flash_attn( @@ -7215,6 +7415,87 @@ struct ggml_tensor * ggml_win_unpart( return result; } +// ggml_get_rel_pos + +struct ggml_tensor * ggml_get_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + int qh, + int kh) { + GGML_ASSERT(qh == kh); + GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); + + bool is_node = false; + + if (a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne); + + result->op = GGML_OP_GET_REL_POS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = NULL; + + return result; +} + +// ggml_add_rel_pos + +static struct ggml_tensor * ggml_add_rel_pos_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(pw, ph)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(ggml_is_contiguous(pw)); + GGML_ASSERT(ggml_is_contiguous(ph)); + GGML_ASSERT(ph->type == GGML_TYPE_F32); + GGML_ASSERT(pw->type == GGML_TYPE_F32); + GGML_ASSERT(pw->ne[3] == a->ne[2]); + GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]); + GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]); + + bool is_node = false; + + if (!inplace && (a->grad || pw->grad || ph->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); + + result->op = GGML_OP_ADD_REL_POS; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = pw; + result->src[2] = ph; + + return result; +} + + +struct ggml_tensor * ggml_add_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph) { + return ggml_add_rel_pos_impl(ctx, a, pw, ph, false); +} + +struct ggml_tensor * ggml_add_rel_pos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph) { + return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); +} + // gmml_unary static struct ggml_tensor * ggml_unary_impl( @@ -9718,6 +9999,72 @@ static void ggml_compute_forward_repeat_back( } } +// ggml_compute_forward_concat + +static void ggml_compute_forward_concat_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + + GGML_TENSOR_BINARY_OP_LOCALS; + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2++) { + if (i2 < ne02) { // src0 + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03); + + float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + *y = *x; + } + } + } // src1 + else { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13); + + float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + *y = *x; + } + } + } + } + } +} + +static void ggml_compute_forward_concat( + const struct ggml_compute_params* params, + const struct ggml_tensor* src0, + const struct ggml_tensor* src1, + struct ggml_tensor* dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_concat_f32(params, src0, src1, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_abs static void ggml_compute_forward_abs_f32( @@ -10321,6 +10668,8 @@ static void ggml_compute_forward_norm( } } +// ggml_compute_forward_group_rms_norm + static void ggml_compute_forward_rms_norm_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -10385,7 +10734,6 @@ static void ggml_compute_forward_rms_norm( } } - static void ggml_compute_forward_rms_norm_back_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -10559,6 +10907,96 @@ static void ggml_compute_forward_rms_norm_back( } } +// ggml_compute_forward_group_norm + +static void ggml_compute_forward_group_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS; + + const float eps = 1e-6f; // TODO: make this a parameter + + // TODO: optimize + + int n_channels = src0->ne[2]; + int n_groups = dst->op_params[0]; + int n_channels_per_group = (n_channels + n_groups - 1) / n_groups; + for (int i = ith; i < n_groups; i+=nth) { + int start = i * n_channels_per_group; + int end = start + n_channels_per_group; + if (end > n_channels) { + end = n_channels; + } + int step = end - start; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + ggml_float sum = 0.0; + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)x[i00]; + } + } + } + float mean = sum / (ne00 * ne01 * step); + ggml_float sum2 = 0.0; + + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + + float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00] - mean; + y[i00] = v; + sum2 += (ggml_float)(v * v); + } + } + } + float variance = sum2 / (ne00 * ne01 * step); + const float scale = 1.0f / sqrtf(variance + eps); + + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + ggml_vec_scale_f32(ne00, y, scale); + } + } + } + } +} + +static void ggml_compute_forward_group_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_group_norm_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_mul_mat #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) @@ -10625,6 +11063,10 @@ static void ggml_compute_forward_mul_mat( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + // nb01 >= nb00 - src0 is not transposed // compute by src0 rows @@ -10644,11 +11086,6 @@ static void ggml_compute_forward_mul_mat( #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { - // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension - // ref: https://github.com/ggerganov/ggml/pull/224 - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne03 == ne13); - if (params->ith != 0) { return; } @@ -10661,12 +11098,16 @@ static void ggml_compute_forward_mul_mat( return; } - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - const void * x = (char *) src0->data + i03*nb03 + i02*nb02; - const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13); + for (int64_t i13 = 0; i13 < ne13; i13++) { + for (int64_t i12 = 0; i12 < ne12; i12++) { + // broadcast src0 into src1 across 2nd,3rd dimension + const int64_t i03 = i13/r3; + const int64_t i02 = i12/r2; - float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + const void * x = (char *) src0->data + i02*nb02 + i03*nb03; + const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); + + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); if (type != GGML_TYPE_F32) { float * const wdata = params->wdata; @@ -10674,7 +11115,7 @@ static void ggml_compute_forward_mul_mat( size_t id = 0; for (int64_t i01 = 0; i01 < ne01; ++i01) { - to_float((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00); + to_float((const char *) x + i01*nb01, wdata + id, ne00); id += ne00; } @@ -10754,10 +11195,6 @@ static void ggml_compute_forward_mul_mat( assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); - // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; - // block-tiling attempt const int64_t blck_0 = 16; const int64_t blck_1 = 16; @@ -11913,7 +12350,6 @@ static void ggml_compute_forward_alibi( } } - // ggml_compute_forward_clamp static void ggml_compute_forward_clamp_f32( @@ -12002,12 +12438,18 @@ static void ggml_compute_forward_rope_f32( float freq_base; float freq_scale; + // these two only relevant for xPos RoPE: + float xpos_base; + bool xpos_down; + const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); assert(n_past >= 0); @@ -12079,6 +12521,9 @@ static void ggml_compute_forward_rope_f32( for (int64_t i0 = 0; i0 < ne0; i0 += 2) { const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); + // zeta scaling for xPos only: + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + if (xpos_down) zeta = 1.0f / zeta; theta *= theta_scale; @@ -12088,8 +12533,8 @@ static void ggml_compute_forward_rope_f32( const float x0 = src[0]; const float x1 = src[1]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; + dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta; + dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; } } else { // TODO: this is probably wrong, but I can't figure it out .. @@ -12283,9 +12728,21 @@ static void ggml_compute_forward_rope_back_f32( // dx = rope_back(dy, src1) // src0 is dy, src1 contains options + float freq_base; + float freq_scale; + + // these two only relevant for xPos RoPE: + float xpos_base; + bool xpos_down; + const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx); + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); assert(n_past >= 0); @@ -12311,7 +12768,7 @@ static void ggml_compute_forward_rope_back_f32( // row index used to determine which thread to use int ir = 0; - const float theta_scale = powf(10000.0, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f/n_dims); const bool is_neox = mode & 2; @@ -12322,12 +12779,15 @@ static void ggml_compute_forward_rope_back_f32( if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = (float)p; + float theta = freq_scale * (float)p; if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); + // zeta scaling for xPos only: + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + if (xpos_down) zeta = 1.0f / zeta; theta *= theta_scale; @@ -12337,8 +12797,8 @@ static void ggml_compute_forward_rope_back_f32( const float dy0 = dy[0]; const float dy1 = dy[1]; - dx[0] = dy0*cos_theta + dy1*sin_theta; - dx[1] = - dy0*sin_theta + dy1*cos_theta; + dx[0] = dy0*cos_theta*zeta + dy1*sin_theta*zeta; + dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta; } } else { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { @@ -13031,6 +13491,108 @@ static void ggml_compute_forward_conv_2d( } } +// ggml_compute_forward_conv_transpose_2d + +static void ggml_compute_forward_conv_transpose_2d( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * opt0, + struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02*ne03; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == GGML_TASK_INIT) { + memset(params->wdata, 0, params->wsize); + + // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); + ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; + for (int64_t i01 = 0; i01 < ne01; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; + } + } + } + } + } + + // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + for (int i12 = 0; i12 < ne12; i12++) { + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); + ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { + return; + } + + const int32_t stride = ((const int32_t*)(opt0->data))[0]; + + // total patches in dst + const int np = ne2; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata_src = (ggml_fp16_t *) params->wdata + nk; + + for (int i2 = ip0; i2 < ip1; i2++) { // Cout + float * dst_data = (float *)((char *) dst->data + i2*nb2); + ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; + for (int i11 = 0; i11 < ne11; i11++) { + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i11*ne10*ne12 + i10*ne12; + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f16(ne03, &v, + (ggml_fp16_t *) wdata_src + i1n, + (ggml_fp16_t *) wdata_kernel + i01*ne00*ne03 + i00*ne03); + + dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; + } + } + } + } + } +} + // ggml_compute_forward_pool_1d_sk_p0 static void ggml_compute_forward_pool_1d_sk_p0( @@ -13189,6 +13751,60 @@ static void ggml_compute_forward_pool_2d( ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst); } +// ggml_compute_forward_upscale + +static void ggml_compute_forward_upscale_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + + GGML_TENSOR_UNARY_OP_LOCALS; + + const int scale_factor = dst->op_params[0]; + + // TODO: optimize + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = ith; i02 < ne02; i02++) { + for (int m = 0; m < dst->ne[1]; m++) { + int i01 = m / scale_factor; + for (int n = 0; n < dst->ne[0]; n++) { + int i00 = n / scale_factor; + + const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03); + + float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]); + + *y = *x; + } + } + } + } +} + +static void ggml_compute_forward_upscale( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_upscale_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} // ggml_compute_forward_flash_attn @@ -14314,6 +14930,137 @@ static void ggml_compute_forward_unary( } } +// ggml_compute_forward_get_rel_pos + +static void ggml_compute_forward_get_rel_pos_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 + + GGML_TENSOR_UNARY_OP_LOCALS; + + const int64_t w = ne1; + + ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; + ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + const int64_t pos = (w - i1 - 1) + i2; + for (int64_t i0 = 0; i0 < ne0; ++i0) { + dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; + } + } + } +} + +static void ggml_compute_forward_get_rel_pos( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rel_pos_f16(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +// ggml_compute_forward_add_rel_pos + +static void ggml_compute_forward_add_rel_pos_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * src2, + struct ggml_tensor * dst) { + + const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; + if (!inplace && params->type == GGML_TASK_INIT) { + memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); + return; + } + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 + + float * src1_data = (float *) src1->data; + float * src2_data = (float *) src2->data; + float * dst_data = (float *) dst->data; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int ith = params->ith; + const int nth = params->nth; + + // total patches in dst + const int np = ne13; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + + for (int64_t i13 = ip0; i13 < ip1; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10; + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t jp0 = jp1 + i10; + const float src1_e = src1_data[jp0]; + const float src2_e = src2_data[jp0]; + + const int64_t jdh = jp0 * ne10; + const int64_t jdw = jdh - (ne10 - 1) * i10; + + for (int64_t j = 0; j < ne10; ++j) { + dst_data[jdh + j ] += src2_e; + dst_data[jdw + j*ne10] += src1_e; + } + } + } + } + } +} + +static void ggml_compute_forward_add_rel_pos( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * src2, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_rel_pos_f32(params, src0, src1, src2, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_map_unary static void ggml_compute_forward_map_unary_f32( @@ -14866,6 +15613,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_repeat_back(params, tensor->src[0], tensor); } break; + case GGML_OP_CONCAT: + { + ggml_compute_forward_concat(params, tensor->src[0], tensor->src[1], tensor); + } break; case GGML_OP_SILU_BACK: { ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor); @@ -14882,6 +15633,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor); } break; + case GGML_OP_GROUP_NORM: + { + ggml_compute_forward_group_norm(params, tensor->src[0], tensor); + } break; case GGML_OP_MUL_MAT: { ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor); @@ -14974,6 +15729,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor); } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } break; case GGML_OP_POOL_1D: { ggml_compute_forward_pool_1d(params, tensor->src[0], tensor); @@ -14982,6 +15741,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_pool_2d(params, tensor->src[0], tensor); } break; + case GGML_OP_UPSCALE: + { + ggml_compute_forward_upscale(params, tensor->src[0], tensor); + } break; case GGML_OP_FLASH_ATTN: { const int32_t t = ggml_get_op_params_i32(tensor, 0); @@ -15012,6 +15775,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_unary(params, tensor->src[0], tensor); } break; + case GGML_OP_GET_REL_POS: + { + ggml_compute_forward_get_rel_pos(params, tensor->src[0], tensor); + } break; + case GGML_OP_ADD_REL_POS: + { + ggml_compute_forward_add_rel_pos(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + } break; case GGML_OP_MAP_UNARY: { ggml_unary_op_f32_t fun; @@ -15275,6 +16046,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor inplace); } } break; + case GGML_OP_CONCAT: + { + GGML_ASSERT(false); // TODO: implement + } break; case GGML_OP_SILU_BACK: { GGML_ASSERT(false); // TODO: not implemented @@ -15297,6 +16072,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_GROUP_NORM: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_MUL_MAT: { // https://cs231n.github.io/optimization-2/#staged @@ -15571,6 +16350,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; + float freq_base; + float freq_scale; + float xpos_base; + bool xpos_down; + memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float)); + memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); + memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); + src0->grad = ggml_add_impl(ctx, src0->grad, ggml_rope_back(ctx, @@ -15578,7 +16366,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor n_past, n_dims, mode, - n_ctx), + n_ctx, + freq_base, + freq_scale, + xpos_base, + xpos_down), inplace); } } break; @@ -15589,14 +16381,28 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; + float freq_base; + float freq_scale; + float xpos_base; + bool xpos_down; + memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float)); + memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); + memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); + src0->grad = ggml_add_impl(ctx, src0->grad, - ggml_rope(ctx, + ggml_rope_impl(ctx, tensor->grad, n_past, n_dims, mode, - n_ctx), + n_ctx, + freq_base, + freq_scale, + xpos_base, + xpos_down, + false), inplace); } } break; @@ -15616,6 +16422,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_POOL_1D: { GGML_ASSERT(false); // TODO: not implemented @@ -15624,6 +16434,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_UPSCALE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_FLASH_ATTN: { struct ggml_tensor * flash_grad = NULL; @@ -15865,6 +16679,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); } } break; + case GGML_OP_GET_REL_POS: + case GGML_OP_ADD_REL_POS: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: @@ -16441,9 +17257,11 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_GROUP_NORM: { n_tasks = n_threads; } break; + case GGML_OP_CONCAT: case GGML_OP_MUL_MAT: case GGML_OP_OUT_PROD: { @@ -16511,6 +17329,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: + case GGML_OP_ADD_REL_POS: { n_tasks = n_threads; } break; @@ -16585,6 +17404,25 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { GGML_ASSERT(false); } + work_size = MAX(work_size, cur); + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + n_tasks = n_threads; + + const int64_t ne00 = node->src[0]->ne[0]; // W + const int64_t ne01 = node->src[0]->ne[1]; // H + const int64_t ne02 = node->src[0]->ne[2]; // Channels Out + const int64_t ne03 = node->src[0]->ne[3]; // Channels In + + const int64_t ne10 = node->src[1]->ne[0]; // W + const int64_t ne11 = node->src[1]->ne[1]; // H + const int64_t ne12 = node->src[1]->ne[2]; // Channels In + + size_t cur = 0; + cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; + cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; + work_size = MAX(work_size, cur); } break; case GGML_OP_POOL_1D: @@ -16592,6 +17430,10 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { { n_tasks = 1; } break; + case GGML_OP_UPSCALE: + { + n_tasks = n_threads; + } break; case GGML_OP_FLASH_ATTN: { n_tasks = n_threads; @@ -16653,6 +17495,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: case GGML_OP_MAP_UNARY: case GGML_OP_MAP_BINARY: case GGML_OP_MAP_CUSTOM1_F32: @@ -16770,8 +17613,10 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); GGML_ASSERT(rc == 0); + UNUSED(rc); } } + workers[0].ith = 0; workers[0].shared = &state_shared; diff --git a/ggml.h b/ggml.h index 0ec7ec5bf..3c48fd27f 100644 --- a/ggml.h +++ b/ggml.h @@ -211,6 +211,7 @@ #define GGML_MAX_OP_PARAMS 32 #define GGML_DEFAULT_N_THREADS 4 + #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 @@ -345,10 +346,12 @@ extern "C" { GGML_OP_ARGMAX, GGML_OP_REPEAT, GGML_OP_REPEAT_BACK, + GGML_OP_CONCAT, GGML_OP_SILU_BACK, GGML_OP_NORM, // normalize GGML_OP_RMS_NORM, GGML_OP_RMS_NORM_BACK, + GGML_OP_GROUP_NORM, GGML_OP_MUL_MAT, GGML_OP_OUT_PROD, @@ -374,14 +377,19 @@ extern "C" { GGML_OP_CLAMP, GGML_OP_CONV_1D, GGML_OP_CONV_2D, + GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, + GGML_OP_UPSCALE, // nearest interpolate + GGML_OP_FLASH_ATTN, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_WIN_PART, GGML_OP_WIN_UNPART, + GGML_OP_GET_REL_POS, + GGML_OP_ADD_REL_POS, GGML_OP_UNARY, @@ -805,6 +813,13 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // concat a and b on dim 2 + // used in stable-diffusion + GGML_API struct ggml_tensor * ggml_concat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_abs( struct ggml_context * ctx, struct ggml_tensor * a); @@ -913,6 +928,19 @@ extern "C" { struct ggml_tensor * a, float eps); + // group normalize along ne0*ne1*n_groups + // used in stable-diffusion + // TODO: eps is hardcoded to 1e-6 for now + GGML_API struct ggml_tensor * ggml_group_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups); + + GGML_API struct ggml_tensor * ggml_group_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups); + // a - x // b - dy // TODO: update with configurable eps @@ -1213,6 +1241,15 @@ extern "C" { float freq_base, float freq_scale); + // xPos RoPE, in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + float base, + bool down); + // rotary position embedding backward, i.e compute dx from dy // a - dy GGML_API struct ggml_tensor * ggml_rope_back( @@ -1221,7 +1258,11 @@ extern "C" { int n_past, int n_dims, int mode, - int n_ctx); + int n_ctx, + float freq_base, + float freq_scale, + float xpos_base, + bool xpos_down); // alibi position embedding // in-place, returns view(a) @@ -1248,6 +1289,15 @@ extern "C" { int p0, // padding int d0); // dilation + // conv_1d with padding = half + // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) + GGML_API struct ggml_tensor* ggml_conv_1d_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s, + int d); + GGML_API struct ggml_tensor * ggml_conv_2d( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1259,14 +1309,38 @@ extern "C" { int d0, int d1); - // conv_1d with padding = half - // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) - GGML_API struct ggml_tensor * ggml_conv_1d_ph( + + // kernel size is a->ne[0] x a->ne[1] + // stride is equal to kernel size + // padding is zero + // example: + // a: 16 16 3 768 + // b: 1024 1024 3 1 + // res: 64 64 768 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // kernel size is a->ne[0] x a->ne[1] + // stride is 1 + // padding is half + // example: + // a: 3 3 256 256 + // b: 64 64 256 1 + // res: 64 64 256 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - int s, - int d); + int stride); enum ggml_op_pool { GGML_OP_POOL_MAX, @@ -1293,6 +1367,13 @@ extern "C" { int p0, int p1); + // nearest interpolate + // used in stable-diffusion + GGML_API struct ggml_tensor * ggml_upscale( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor); + GGML_API struct ggml_tensor * ggml_flash_attn( struct ggml_context * ctx, struct ggml_tensor * q, @@ -1346,6 +1427,27 @@ extern "C" { struct ggml_tensor * a, enum ggml_unary_op op); + // used in sam + GGML_API struct ggml_tensor * ggml_get_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + int qh, + int kh); + + // used in sam + + GGML_API struct ggml_tensor * ggml_add_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph); + + GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph); + // custom operators typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); @@ -1500,6 +1602,7 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index 3d13e852a..e44c3bd03 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -1,14 +1,16 @@ #!/bin/bash -cp -rpv ../ggml/src/ggml.c ./ggml.c -cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h -cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu -cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h -cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp -cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h -cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m -cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal -cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h +cp -rpv ../ggml/src/ggml.c ./ggml.c +cp -rpv ../ggml/src/ggml-alloc.c ./ggml-alloc.c +cp -rpv ../ggml/src/ggml-cuda.h ./ggml-cuda.h +cp -rpv ../ggml/src/ggml-cuda.cu ./ggml-cuda.cu +cp -rpv ../ggml/src/ggml-opencl.h ./ggml-opencl.h +cp -rpv ../ggml/src/ggml-opencl.cpp ./ggml-opencl.cpp +cp -rpv ../ggml/src/ggml-metal.h ./ggml-metal.h +cp -rpv ../ggml/src/ggml-metal.m ./ggml-metal.m +cp -rpv ../ggml/src/ggml-metal.metal ./ggml-metal.metal +cp -rpv ../ggml/include/ggml/ggml.h ./ggml.h +cp -rpv ../ggml/include/ggml/ggml-alloc.h ./ggml-alloc.h cp -rpv ../ggml/tests/test-opt.cpp ./tests/test-opt.cpp cp -rpv ../ggml/tests/test-grad0.cpp ./tests/test-grad0.cpp