From 6eac39ba953acaeec396cea2969dbf413907e2ec Mon Sep 17 00:00:00 2001 From: hoangmit Date: Wed, 15 Mar 2023 18:41:38 -0400 Subject: [PATCH] Add RMS norm and use it (#187) * add ggml_rms_norm * update op num --- ggml.c | 128 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- ggml.h | 5 +++ main.cpp | 6 +-- 3 files changed, 134 insertions(+), 5 deletions(-) diff --git a/ggml.c b/ggml.c index a0c0dd03b..eee54f7ff 100644 --- a/ggml.c +++ b/ggml.c @@ -2069,6 +2069,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "GELU", "SILU", "NORM", + "RMS_NORM", "MUL_MAT", @@ -2089,7 +2090,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "FLASH_FF", }; -static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34"); +static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2112,6 +2113,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "gelu(x)", "silu(x)", "norm(x)", + "rms_norm(x)", "X*Y", @@ -2132,7 +2134,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_ff(x)", }; -static_assert(GGML_OP_COUNT == 34, "GGML_OP_COUNT != 34"); +static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); // // ggml object @@ -3618,6 +3620,39 @@ struct ggml_tensor * ggml_norm_inplace( return ggml_norm_impl(ctx, a, true); } +struct ggml_tensor * ggml_rms_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && (a->grad)) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_RMS_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; // TODO: maybe store epsilon here? + + return result; +} + +struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_rms_norm_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_rms_norm_impl(ctx, a, true); +} + // ggml_mul_mat struct ggml_tensor * ggml_mul_mat( @@ -5406,6 +5441,87 @@ static void ggml_compute_forward_norm( } } +static void ggml_compute_forward_rms_norm_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + const size_t nb03 = src0->nb[3]; + + const size_t nb1 = dst->nb[1]; + const size_t nb2 = dst->nb[2]; + const size_t nb3 = dst->nb[3]; + + const ggml_float eps = 1e-5f; // TODO: make this a parameter + + // TODO: optimize + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float mean = 0.0; + for (int i00 = 0; i00 < ne00; i00++) { + mean += x[i00] * x[i00]; + } + + mean /= ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + // for (int i00 = 0; i00 < ne00; i00++) { + // y[i00] = x[i00]; + // } + + const float scale = 1.0/sqrt(mean + eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_rms_norm( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_f32(params, src0, dst); + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_F16: + case GGML_TYPE_COUNT: + { + GGML_ASSERT(false); + } break; + } +} + + // ggml_compute_forward_mul_mat #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) @@ -8522,6 +8638,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_norm(params, tensor->src0, tensor); } break; + case GGML_OP_RMS_NORM: + { + ggml_compute_forward_rms_norm(params, tensor->src0, tensor); + } break; case GGML_OP_MUL_MAT: { ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor); @@ -8764,6 +8884,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_RMS_NORM: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_MUL_MAT: { if (src0->grad) { diff --git a/ggml.h b/ggml.h index 7ce655c1b..bac4fe65c 100644 --- a/ggml.h +++ b/ggml.h @@ -230,6 +230,7 @@ enum ggml_op { GGML_OP_GELU, GGML_OP_SILU, GGML_OP_NORM, // normalize + GGML_OP_RMS_NORM, GGML_OP_MUL_MAT, @@ -482,6 +483,10 @@ struct ggml_tensor * ggml_norm( struct ggml_context * ctx, struct ggml_tensor * a); +struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a); + // A: m rows, n columns // B: p rows, n columns (i.e. we transpose it internally) // result is m columns, p rows diff --git a/main.cpp b/main.cpp index a812d0fa0..ca0fca8b3 100644 --- a/main.cpp +++ b/main.cpp @@ -588,7 +588,7 @@ bool llama_eval( // norm { - cur = ggml_norm(ctx0, inpL); + cur = ggml_rms_norm(ctx0, inpL); // cur = attention_norm*cur cur = ggml_mul(ctx0, @@ -678,7 +678,7 @@ bool llama_eval( { // norm { - cur = ggml_norm(ctx0, inpFF); + cur = ggml_rms_norm(ctx0, inpFF); // cur = ffn_norm*cur cur = ggml_mul(ctx0, @@ -713,7 +713,7 @@ bool llama_eval( // norm { - inpL = ggml_norm(ctx0, inpL); + inpL = ggml_rms_norm(ctx0, inpL); // inpL = norm*inpL inpL = ggml_mul(ctx0,