From 50337961a678fce4081554b24e56e86b67660163 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 1 Nov 2023 20:11:02 +0200 Subject: [PATCH] llm : add llm_build_context (#3881) * llm : add llm_build_context * llm : deduce norm eps based on type + explict max_alibi_bias, clamp_kqv * llm : restore the non-graph llm_build_ functional API ggml-ci * llm : cleanup + comments --- llama.cpp | 2338 ++++++++++++++++++++++++----------------------------- 1 file changed, 1042 insertions(+), 1296 deletions(-) diff --git a/llama.cpp b/llama.cpp index 42cedc7a1..d0c4ef101 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3090,6 +3090,10 @@ static bool llama_model_load( return true; } +// +// llm_build +// + using llm_build_cb = std::function; enum llm_rope_type { @@ -3098,17 +3102,35 @@ enum llm_rope_type { LLM_ROPE_GLM, }; +enum llm_ffn_op_type { + LLM_FFN_SILU, + LLM_FFN_GELU, + LLM_FFN_RELU, + LLM_FFN_RELU_SQR, +}; + +enum llm_ffn_gate_type { + LLM_FFN_SEQ, + LLM_FFN_PAR, // ffn_gate is parallel to ffn_up +}; + +enum llm_norm_type { + LLM_NORM, + LLM_NORM_RMS, +}; + static struct ggml_tensor * llm_build_inp_embd( struct ggml_context * ctx, + const llama_hparams & hparams, const llama_batch & batch, struct ggml_tensor * tok_embd, - int64_t n_embd, - int32_t n_tokens, const llm_build_cb & cb) { + const int64_t n_embd = hparams.n_embd; + struct ggml_tensor * inpL; if (batch.token) { - struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens); cb(inp_tokens, "inp_tokens", -1); inpL = ggml_get_rows(ctx, tok_embd, inp_tokens); @@ -3117,7 +3139,7 @@ static struct ggml_tensor * llm_build_inp_embd( GGML_ASSERT(false && "not implemented"); #endif - inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); + inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); } return inpL; @@ -3126,28 +3148,21 @@ static struct ggml_tensor * llm_build_inp_embd( // Persimmon: n_rot = n_embd_head/2 // Other: n_rot = n_embd_head static void llm_build_k_shift( - const llama_context & lctx, - struct ggml_context * ctx, - struct ggml_cgraph * graph, - int64_t n_rot, - llm_rope_type type, - const llm_build_cb & cb) { - const auto & model = lctx.model; - const auto & kv_self = lctx.kv_self; - const auto & cparams = lctx.cparams; - - const auto & hparams = model.hparams; - + struct ggml_context * ctx, + const llama_hparams & hparams, + const llama_kv_cache & kv, + struct ggml_cgraph * graph, + llm_rope_type type, + int64_t n_ctx, + int64_t n_rot, + float freq_base, + float freq_scale, + const llm_build_cb & cb) { const int64_t n_layer = hparams.n_layer; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_embd_gqa = hparams.n_embd_gqa(); const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_ctx = lctx.cparams.n_ctx; - - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - GGML_ASSERT(n_embd_head % n_rot == 0); struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx); @@ -3165,11 +3180,11 @@ static void llm_build_k_shift( struct ggml_tensor * tmp = // we rotate only the first n_rot dimensions ggml_rope_custom_inplace(ctx, - ggml_view_3d(ctx, kv_self.k, + ggml_view_3d(ctx, kv.k, n_rot, n_head_kv, n_ctx, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), + ggml_element_size(kv.k)*n_embd_head, + ggml_element_size(kv.k)*n_embd_gqa, + ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il), K_shift, n_rot, rope_type, 0, freq_base, freq_scale); cb(tmp, "K_shifted", il); ggml_build_forward_expand(graph, tmp); @@ -3177,22 +3192,17 @@ static void llm_build_k_shift( } static void llm_build_kv_store( - const llama_context & lctx, struct ggml_context * ctx, + const llama_hparams & hparams, + const llama_kv_cache & kv, struct ggml_cgraph * graph, struct ggml_tensor * k_cur, struct ggml_tensor * v_cur, + int64_t n_ctx, int32_t n_tokens, int32_t kv_head, const llm_build_cb & cb, int64_t il) { - const auto & model = lctx.model; - const auto & kv_self = lctx.kv_self; - const auto & cparams = lctx.cparams; - - const auto & hparams = model.hparams; - - const int64_t n_ctx = cparams.n_ctx; const int64_t n_embd_gqa = hparams.n_embd_gqa(); // compute the transposed [n_tokens, n_embd] V matrix @@ -3200,13 +3210,13 @@ static void llm_build_kv_store( //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv_self.k, n_tokens*n_embd_gqa, - (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); + struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k, n_tokens*n_embd_gqa, + (ggml_element_size(kv.k)*n_embd_gqa)*(il*n_ctx + kv_head)); cb(k_cache_view, "k_cache_view", il); - struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv_self.v, n_tokens, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv.v), + (il*n_ctx)*ggml_element_size(kv.v)*n_embd_gqa + kv_head*ggml_element_size(kv.v)); cb(v_cache_view, "v_cache_view", il); // important: storing RoPE-ed version of K in the KV cache! @@ -3214,23 +3224,18 @@ static void llm_build_kv_store( ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); } -enum llm_norm_type { - LLM_NORM, - LLM_NORM_RMS, -}; - static struct ggml_tensor * llm_build_norm( struct ggml_context * ctx, struct ggml_tensor * cur, + const llama_hparams & hparams, struct ggml_tensor * mw, struct ggml_tensor * mb, llm_norm_type type, - float eps, const llm_build_cb & cb, int il) { switch (type) { - case LLM_NORM: cur = ggml_norm (ctx, cur, eps); break; - case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, eps); break; + case LLM_NORM: cur = ggml_norm (ctx, cur, hparams.f_norm_eps); break; + case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break; } if (mw || mb) { @@ -3251,18 +3256,6 @@ static struct ggml_tensor * llm_build_norm( return cur; } -enum llm_ffn_op_type { - LLM_FFN_SILU, - LLM_FFN_GELU, - LLM_FFN_RELU, - LLM_FFN_RELU_SQR, -}; - -enum llm_ffn_gate_type { - LLM_FFN_SEQ, - LLM_FFN_PAR, // ffn_gate is parallel to ffn_up -}; - static struct ggml_tensor * llm_build_ffn( struct ggml_context * ctx, struct ggml_tensor * cur, @@ -3351,26 +3344,21 @@ static struct ggml_tensor * llm_build_ffn( // if max_alibi_bias > 0 then apply ALiBi static struct ggml_tensor * llm_build_kqv( - const llama_context & lctx, struct ggml_context * ctx, struct ggml_tensor * cur, + const llama_hparams & hparams, + const llama_kv_cache & kv, struct ggml_tensor * wo, struct ggml_tensor * wo_b, struct ggml_tensor * q_cur, struct ggml_tensor * kq_scale, struct ggml_tensor * kq_mask, + int64_t n_ctx, int32_t n_tokens, int32_t n_kv, - float alibi_bias_max, + float max_alibi_bias, const llm_build_cb & cb, - int il) { - const auto & model = lctx.model; - const auto & kv_self = lctx.kv_self; - const auto & cparams = lctx.cparams; - - const auto & hparams = model.hparams; - - const int64_t n_ctx = cparams.n_ctx; + int il) { const int64_t n_embd = hparams.n_embd; const int64_t n_head = hparams.n_head; const int64_t n_head_kv = hparams.n_head_kv; @@ -3381,11 +3369,11 @@ static struct ggml_tensor * llm_build_kqv( cb(q, "q", il); struct ggml_tensor * k = - ggml_view_3d(ctx, kv_self.k, + ggml_view_3d(ctx, kv.k, n_embd_head, n_kv, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + ggml_element_size(kv.k)*n_embd_gqa, + ggml_element_size(kv.k)*n_embd_head, + ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il); cb(k, "k", il); struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); @@ -3394,11 +3382,11 @@ static struct ggml_tensor * llm_build_kqv( kq = ggml_scale(ctx, kq, kq_scale); cb(kq, "kq_scaled", il); - if (alibi_bias_max > 0.0f) { + if (max_alibi_bias > 0.0f) { // TODO: n_head or n_head_kv // TODO: K-shift is likely not working // TODO: change to ggml_add - kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, alibi_bias_max); + kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias); cb(kq, "kq_scaled_alibi", il); } @@ -3410,11 +3398,11 @@ static struct ggml_tensor * llm_build_kqv( // split cached v into n_head heads struct ggml_tensor * v = - ggml_view_3d(ctx, kv_self.v, + ggml_view_3d(ctx, kv.v, n_kv, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + ggml_element_size(kv.v)*n_ctx, + ggml_element_size(kv.v)*n_ctx*n_embd_head, + ggml_element_size(kv.v)*n_ctx*n_embd_gqa*il); cb(v, "v", il); struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); @@ -3438,1259 +3426,1011 @@ static struct ggml_tensor * llm_build_kqv( return cur; } -static struct ggml_cgraph * llm_build_llama( +struct llm_build_context { + const llama_model & model; + const llama_hparams & hparams; + const llama_cparams & cparams; + const llama_batch & batch; + const llama_kv_cache & kv_self; + + const int64_t n_embd; + const int64_t n_layer; + const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) + const int64_t n_head; + const int64_t n_head_kv; + const int64_t n_embd_head; + const int64_t n_embd_gqa; + + const float freq_base; + const float freq_scale; + const float norm_eps; + const float norm_rms_eps; + + const int32_t n_tokens; + const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx) + const int32_t kv_head; // index of where we store new KV data in the cache + + const bool do_rope_shift; + + const llm_build_cb & cb; + + llama_buffer & buf_compute; + + struct ggml_context * ctx0 = nullptr; + + // TODO: consider making the entire interface noexcept + llm_build_context( llama_context & lctx, const llama_batch & batch, const llm_build_cb & cb, - bool worst_case) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; + bool worst_case) : + model (lctx.model), + hparams (model.hparams), + cparams (lctx.cparams), + batch (batch), + kv_self (lctx.kv_self), + n_embd (hparams.n_embd), + n_layer (hparams.n_layer), + n_ctx (cparams.n_ctx), + n_head (hparams.n_head), + n_head_kv (hparams.n_head_kv), + n_embd_head (hparams.n_embd_head()), + n_embd_gqa (hparams.n_embd_gqa()), + freq_base (cparams.rope_freq_base), + freq_scale (cparams.rope_freq_scale), + norm_eps (hparams.f_norm_eps), + norm_rms_eps (hparams.f_norm_rms_eps), + n_tokens (batch.n_tokens), + n_kv (worst_case ? n_ctx : kv_self.n), + kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), + do_rope_shift (worst_case || kv_self.has_shift), + cb (cb), + buf_compute (lctx.buf_compute) { + GGML_ASSERT(!!kv_self.ctx); - const auto & kv_self = lctx.kv_self; + // all initializations should be done in init() + } - GGML_ASSERT(!!kv_self.ctx); + void init() { + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ true, + }; - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - - GGML_ASSERT(n_embd_head == hparams.n_rot); - - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_rms_eps = hparams.f_norm_rms_eps; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; - - const bool do_rope_shift = worst_case || kv_self.has_shift; - - //printf("n_kv = %d\n", n_kv); - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb); - cb(inpL, "inp_embd", -1); - - // inp_pos - contains the positions - struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(inp_pos, "inp_pos", -1); - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - cb(KQ_mask, "KQ_mask", -1); - - // shift the entire K-cache if needed - if (do_rope_shift) { - llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb); + ctx0 = ggml_init(params); } - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; + void free() { + if (ctx0) { + ggml_free(ctx0); + ctx0 = nullptr; + } + } + + struct ggml_cgraph * build_llama() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, cur, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_baichuan() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + switch (model.type) { + case MODEL_7B: + Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); + Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); + break; + case MODEL_13B: + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens); + break; + default: + GGML_ASSERT(false); + } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + // apply ALiBi for 13B model + const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f; + + cur = llm_build_kqv(ctx0, cur, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_falcon() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * attn_norm; + + attn_norm = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(attn_norm, "attn_norm", il); + + // self-attention + { + if (model.layers[il].attn_norm_2) { + // Falcon-40B + cur = llm_build_norm(ctx0, attn_norm, hparams, + model.layers[il].attn_norm_2, + model.layers[il].attn_norm_2_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm_2", il); + } else { + cur = attn_norm; + } + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + // using mode = 2 for neox mode + Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale); + cb(Kcur, "Kcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, attn_norm, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = cur; + + // feed forward + { + cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result + model.layers[il].ffn_up, NULL, + NULL, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; // norm - cur = llm_build_norm(ctx0, inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, norm_rms_eps, cb, il); - cb(cur, "attn_norm", il); + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); - // self-attention - { - // compute Q and K and RoPE them - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); + ggml_build_forward_expand(gf, cur); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); - cb(Qcur, "Qcur", il); - - Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); - cb(Kcur, "Kcur", il); - - llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - - cur = llm_build_kqv(lctx, ctx0, cur, - model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il); - cb(cur, "kqv_out", il); - } - - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = llm_build_norm(ctx0, ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, norm_rms_eps, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, - LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "ffn_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; + return gf; } - cur = inpL; + struct ggml_cgraph * build_starcoder() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); - cur = llm_build_norm(ctx0, cur, - model.output_norm, NULL, - LLM_NORM_RMS, norm_rms_eps, cb, -1); - cb(cur, "result_norm", -1); + struct ggml_tensor * cur; + struct ggml_tensor * pos; + struct ggml_tensor * inpL; - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); - ggml_build_forward_expand(gf, cur); + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); - ggml_free(ctx0); + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); - return gf; -} + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); -static struct ggml_cgraph * llm_build_baichaun( - llama_context & lctx, - const llama_batch & batch, - const llm_build_cb & cb, - bool worst_case) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; + pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); - const auto & kv_self = lctx.kv_self; + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); - GGML_ASSERT(!!kv_self.ctx); + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); - GGML_ASSERT(n_embd_head == hparams.n_rot); + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_rms_eps = hparams.f_norm_rms_eps; + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); - const bool do_rope_shift = worst_case || kv_self.has_shift; + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - auto & buf_compute = lctx.buf_compute; + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb); - cb(inpL, "inp_embd", -1); - - // inp_pos - contains the positions - struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(inp_pos, "inp_pos", -1); - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - cb(KQ_mask, "KQ_mask", -1); - - // shift the entire K-cache if needed - if (do_rope_shift) { - llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb); - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - cur = llm_build_norm(ctx0, inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, norm_rms_eps, cb, il); - cb(cur, "attn_norm", il); - - // self-attention - { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - switch (model.type) { - case MODEL_7B: - Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); - Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale); - break; - case MODEL_13B: - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens); - break; - default: - GGML_ASSERT(false); - } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - - llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - - // apply ALiBi for 13B model - const float alibi_bias_max = model.type == MODEL_13B ? 8.0f : -1.0f; - - cur = llm_build_kqv(lctx, ctx0, cur, - model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, alibi_bias_max, cb, il); - cb(cur, "kqv_out", il); - } - - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = llm_build_norm(ctx0, ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, norm_rms_eps, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, - LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "ffn_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = llm_build_norm(ctx0, cur, - model.output_norm, NULL, - LLM_NORM_RMS, norm_rms_eps, cb, -1); - cb(cur, "result_norm", -1); - - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_falcon( - llama_context & lctx, - const llama_batch & batch, - const llm_build_cb & cb, - bool worst_case) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - GGML_ASSERT(n_embd_head == hparams.n_rot); - - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_eps = hparams.f_norm_eps; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; - - const bool do_rope_shift = worst_case || kv_self.has_shift; - - //printf("kv_head = %d, n_kv = %d, n_tokens = %d, n_ctx = %d, is_measure = %d, has_shift = %d\n", - // kv_head, n_kv, n_tokens, n_ctx, ggml_allocr_is_measure(lctx.alloc), kv_self.has_shift); - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb); - cb(inpL, "inp_embd", -1); - - // inp_pos - contains the positions - struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(inp_pos, "inp_pos", -1); - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - cb(KQ_mask, "KQ_mask", -1); - - // shift the entire K-cache if needed - if (do_rope_shift) { - llm_build_k_shift(lctx, ctx0, gf, n_embd_head, LLM_ROPE_NEOX, cb); - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * attn_norm; - - attn_norm = llm_build_norm(ctx0, inpL, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, norm_eps, cb, il); - cb(attn_norm, "attn_norm", il); - - // self-attention - { - if (model.layers[il].attn_norm_2) { - // Falcon-40B - cur = llm_build_norm(ctx0, attn_norm, - model.layers[il].attn_norm_2, - model.layers[il].attn_norm_2_b, - LLM_NORM, norm_eps, cb, il); - cb(cur, "attn_norm_2", il); - } else { - cur = attn_norm; + cur = llm_build_kqv(ctx0, cur, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); } - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - - // using mode = 2 for neox mode - Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale); - cb(Qcur, "Qcur", il); - - Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale); - cb(Kcur, "Kcur", il); - - llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - - cur = llm_build_kqv(lctx, ctx0, attn_norm, - model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il); - cb(cur, "kqv_out", il); - } - - struct ggml_tensor * ffn_inp = cur; - - // feed forward - { - cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result - model.layers[il].ffn_up, NULL, - NULL, NULL, - model.layers[il].ffn_down, NULL, - LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); - cb(cur, "ffn_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); - - cur = ggml_add(ctx0, cur, inpL); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - // norm - cur = llm_build_norm(ctx0, cur, - model.output_norm, - model.output_norm_b, - LLM_NORM, norm_eps, cb, -1); - cb(cur, "result_norm", -1); - - cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_starcoder( - llama_context & lctx, - const llama_batch & batch, - const llm_build_cb & cb, - bool worst_case) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - GGML_ASSERT(n_embd_head == hparams.n_rot); - - const float norm_eps = hparams.f_norm_eps; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * pos; - struct ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb); - cb(inpL, "inp_embd", -1); - - // inp_pos - contains the positions - struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(inp_pos, "inp_pos", -1); - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - cb(KQ_mask, "KQ_mask", -1); - - pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); - cb(pos, "pos_embd", -1); - - inpL = ggml_add(ctx0, inpL, pos); - cb(inpL, "inpL", -1); - - for (int il = 0; il < n_layer; ++il) { - cur = llm_build_norm(ctx0, inpL, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, norm_eps, cb, il); - cb(cur, "attn_norm", il); - - // self-attention - { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - - llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - - cur = llm_build_kqv(lctx, ctx0, cur, - model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il); - cb(cur, "kqv_out", il); - } - - // add the input - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); - cb(ffn_inp, "ffn_inp", il); - - // FF - { - cur = llm_build_norm(ctx0, ffn_inp, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, norm_eps, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, - LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); - cb(cur, "ffn_out", il); - } - - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); - } - - cur = llm_build_norm(ctx0, inpL, - model.output_norm, - model.output_norm_b, - LLM_NORM, norm_eps, cb, -1); - cb(cur, "result_norm", -1); - - cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); - - ggml_build_forward_expand(gf, cur); - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_persimmon( - llama_context & lctx, - const llama_batch & batch, - const llm_build_cb & cb, - bool worst_case) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const auto & cparams = lctx.cparams; - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_head = hparams.n_head; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_rot = n_embd_head / 2; - - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_eps = hparams.f_norm_eps; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; - - const bool do_rope_shift = worst_case || kv_self.has_shift; - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb); - cb(inpL, "imp_embd", -1); - - struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(inp_pos, "inp_pos", -1); - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - cb(KQ_mask, "KQ_mask", -1); - - if (do_rope_shift) { - llm_build_k_shift(lctx, ctx0, gf, n_rot, LLM_ROPE_NEOX, cb); - } - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * residual = inpL; - - cur = llm_build_norm(ctx0, inpL, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, norm_eps, cb, il); - cb(cur, "attn_norm", il); - - // self attention - { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - // split qkv - GGML_ASSERT(n_head_kv == n_head); - - struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens); - cb(tmpqkv, "tmpqkv", il); - - struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2)); - cb(tmpqkv_perm, "tmpqkv", il); - - struct ggml_tensor * tmpq = ggml_view_3d( - ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - ggml_element_size(tmpqkv_perm) * n_embd_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - 0 - ); - cb(tmpq, "tmpq", il); - - struct ggml_tensor * tmpk = ggml_view_3d( - ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - ggml_element_size(tmpqkv_perm) * n_embd_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens - ); - cb(tmpk, "tmpk", il); - - // Q/K Layernorm - tmpq = llm_build_norm(ctx0, tmpq, - model.layers[il].attn_q_norm, - model.layers[il].attn_q_norm_b, - LLM_NORM, norm_eps, cb, il); - cb(tmpq, "tmpq", il); - - tmpk = llm_build_norm(ctx0, tmpk, - model.layers[il].attn_k_norm, - model.layers[il].attn_k_norm_b, - LLM_NORM, norm_eps, cb, il); - cb(tmpk, "tmpk", il); - - // RoPE the first n_rot of q/k, pass the other half, and concat. - struct ggml_tensor * qrot = ggml_view_3d( - ctx0, tmpq, n_rot, n_head, n_tokens, - ggml_element_size(tmpq) * n_embd_head, - ggml_element_size(tmpq) * n_embd_head * n_head, - 0 - ); - cb(qrot, "qrot", il); - - struct ggml_tensor * krot = ggml_view_3d( - ctx0, tmpk, n_rot, n_head, n_tokens, - ggml_element_size(tmpk) * n_embd_head, - ggml_element_size(tmpk) * n_embd_head * n_head, - 0 - ); - cb(krot, "krot", il); - - // get the second half of tmpq, e.g tmpq[n_rot:, :, :] - struct ggml_tensor * qpass = ggml_view_3d( - ctx0, tmpq, n_rot, n_head, n_tokens, - ggml_element_size(tmpq) * n_embd_head, - ggml_element_size(tmpq) * n_embd_head * n_head, - ggml_element_size(tmpq) * n_rot - ); - cb(qpass, "qpass", il); - - struct ggml_tensor * kpass = ggml_view_3d( - ctx0, tmpk, n_rot, n_head, n_tokens, - ggml_element_size(tmpk) * n_embd_head, - ggml_element_size(tmpk) * n_embd_head * n_head, - ggml_element_size(tmpk) * n_rot - ); - cb(kpass, "kpass", il); - - struct ggml_tensor * qrotated = ggml_rope_custom( - ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale - ); - cb(qrotated, "qrotated", il); - - struct ggml_tensor * krotated = ggml_rope_custom( - ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale - ); - cb(krotated, "krotated", il); - - // ggml currently only supports concatenation on dim=2 - // so we need to permute qrot, qpass, concat, then permute back. - qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3)); - cb(qrotated, "qrotated", il); - - krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3)); - cb(krotated, "krotated", il); - - qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3)); - cb(qpass, "qpass", il); - - kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3)); - cb(kpass, "kpass", il); - - struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3)); - cb(Q, "Q", il); - - Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3)); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = ggml_view_3d( - ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - ggml_element_size(tmpqkv_perm) * n_embd_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2 - ); - cb(Vcur, "Vcur", il); - - llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - - // TODO: not tested, could be broken - cur = llm_build_kqv(lctx, ctx0, Q, - model.layers[il].wo, model.layers[il].bo, - Q, KQ_scale, KQ_mask, n_tokens, n_kv, -1.0f, cb, il); - cb(cur, "kqv_out", il); - } - - struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = llm_build_norm(ctx0, ffn_inp, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, norm_eps, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, - LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il); - cb(cur, "ffn_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); - - inpL = cur; - } - - cur = inpL; - - cur = llm_build_norm(ctx0, cur, - model.output_norm, - model.output_norm_b, - LLM_NORM, norm_eps, cb, -1); - cb(cur, "result_norm", -1); - - cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_refact( - llama_context & lctx, - const llama_batch & batch, - const llm_build_cb & cb, - bool worst_case) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; - const int64_t n_embd_head = hparams.n_embd_head(); - - const float norm_rms_eps = hparams.f_norm_rms_eps; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, - }; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb); - cb(inpL, "inp_embd", -1); - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - cb(KQ_mask, "KQ_mask", -1); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - cur = llm_build_norm(ctx0, inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, norm_rms_eps, cb, il); - cb(cur, "attn_norm", il); - - // self-attention - { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - cb(Kcur, "Kcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cb(Qcur, "Qcur", il); - - llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - - cur = llm_build_kqv(lctx, ctx0, Qcur, - model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, 8.0f, cb, il); - cb(cur, "kqv_out", il); - } - - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = llm_build_norm(ctx0, ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, norm_rms_eps, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, - LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "ffn_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = llm_build_norm(ctx0, cur, - model.output_norm, NULL, - LLM_NORM_RMS, norm_rms_eps, cb, -1); - cb(cur, "result_norm", -1); - - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_bloom( - llama_context & lctx, - const llama_batch & batch, - const llm_build_cb & cb, - bool worst_case) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - GGML_ASSERT(n_embd_head == hparams.n_rot); - - const float norm_eps = hparams.f_norm_eps; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, - }; - - params.no_alloc = true; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb); - cb(inpL, "inp_embd", -1); - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - cb(KQ_mask, "KQ_mask", -1); - - inpL = llm_build_norm(ctx0, inpL, - model.tok_norm, - model.tok_norm_b, - LLM_NORM, norm_eps, cb, -1); - cb(inpL, "inp_norm", -1); - - for (int il = 0; il < n_layer; ++il) { - cur = llm_build_norm(ctx0, inpL, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, norm_eps, cb, il); - cb(cur, "attn_norm", il); - - // self-attention - { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - - llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - - cur = llm_build_kqv(lctx, ctx0, Qcur, - model.layers[il].wo, model.layers[il].bo, - Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, 8.0f, cb, il); - cb(cur, "kqv_out", il); - } - - // Add the input - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); - cb(ffn_inp, "ffn_inp", il); - - // FF - { - cur = llm_build_norm(ctx0, ffn_inp, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, norm_eps, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, - LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); - cb(cur, "ffn_out", il); - } - - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); - } - - cur = llm_build_norm(ctx0, inpL, - model.output_norm, - model.output_norm_b, - LLM_NORM, norm_eps, cb, -1); - cb(cur, "result_norm", -1); - - cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} - -static struct ggml_cgraph * llm_build_mpt( - llama_context & lctx, - const llama_batch & batch, - const llm_build_cb & cb, - bool worst_case) { - const auto & model = lctx.model; - const auto & hparams = model.hparams; - const auto & cparams = lctx.cparams; - - const auto & kv_self = lctx.kv_self; - - GGML_ASSERT(!!kv_self.ctx); - - const int64_t n_embd = hparams.n_embd; - const int64_t n_layer = hparams.n_layer; - const int64_t n_ctx = cparams.n_ctx; - const int64_t n_head = hparams.n_head; - const int64_t n_embd_head = hparams.n_embd_head(); - const int64_t n_embd_gqa = hparams.n_embd_gqa(); - - const float norm_eps = hparams.f_norm_eps; - const float clamp_kqv = hparams.f_clamp_kqv; - const float max_alibi_bias = hparams.f_max_alibi_bias; - - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; - const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head; - - auto & buf_compute = lctx.buf_compute; - - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, - }; - - params.no_alloc = true; - - struct ggml_context * ctx0 = ggml_init(params); - - ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, batch, model.tok_embd, n_embd, n_tokens, cb); - cb(inpL, "inp_embd", -1); - - // KQ_scale - struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); - cb(KQ_scale, "KQ_scale", -1); - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - cb(KQ_mask, "KQ_mask", -1); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * attn_norm; - - attn_norm = llm_build_norm(ctx0, inpL, - model.layers[il].attn_norm, - NULL, - LLM_NORM, norm_eps, cb, il); - cb(attn_norm, "attn_norm", il); - - // self-attention - { - cur = attn_norm; - - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - if (clamp_kqv > 0.0f) { - cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv); - cb(cur, "wqkv_clamped", il); + // add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); } - struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); - struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); - struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - - llm_build_kv_store(lctx, ctx0, gf, Kcur, Vcur, n_tokens, kv_head, cb, il); - - cur = llm_build_kqv(lctx, ctx0, Qcur, - model.layers[il].wo, NULL, - Qcur, KQ_scale, KQ_mask, n_tokens, n_kv, max_alibi_bias, cb, il); - cb(cur, "kqv_out", il); + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); } - // Add the input - struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); - cb(ffn_inp, "ffn_inp", il); + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); - // feed forward - { - cur = llm_build_norm(ctx0, ffn_inp, - model.layers[il].ffn_norm, + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_persimmon() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + const int64_t n_rot = n_embd_head / 2; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "imp_embd", -1); + + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * residual = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + // split qkv + GGML_ASSERT(n_head_kv == n_head); + + struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens); + cb(tmpqkv, "tmpqkv", il); + + struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2)); + cb(tmpqkv_perm, "tmpqkv", il); + + struct ggml_tensor * tmpq = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + 0 + ); + cb(tmpq, "tmpq", il); + + struct ggml_tensor * tmpk = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens + ); + cb(tmpk, "tmpk", il); + + // Q/K Layernorm + tmpq = llm_build_norm(ctx0, tmpq, hparams, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, cb, il); + cb(tmpq, "tmpq", il); + + tmpk = llm_build_norm(ctx0, tmpk, hparams, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, cb, il); + cb(tmpk, "tmpk", il); + + // RoPE the first n_rot of q/k, pass the other half, and concat. + struct ggml_tensor * qrot = ggml_view_3d( + ctx0, tmpq, n_rot, n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + 0 + ); + cb(qrot, "qrot", il); + + struct ggml_tensor * krot = ggml_view_3d( + ctx0, tmpk, n_rot, n_head, n_tokens, + ggml_element_size(tmpk) * n_embd_head, + ggml_element_size(tmpk) * n_embd_head * n_head, + 0 + ); + cb(krot, "krot", il); + + // get the second half of tmpq, e.g tmpq[n_rot:, :, :] + struct ggml_tensor * qpass = ggml_view_3d( + ctx0, tmpq, n_rot, n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + ggml_element_size(tmpq) * n_rot + ); + cb(qpass, "qpass", il); + + struct ggml_tensor * kpass = ggml_view_3d( + ctx0, tmpk, n_rot, n_head, n_tokens, + ggml_element_size(tmpk) * n_embd_head, + ggml_element_size(tmpk) * n_embd_head * n_head, + ggml_element_size(tmpk) * n_rot + ); + cb(kpass, "kpass", il); + + struct ggml_tensor * qrotated = ggml_rope_custom( + ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale + ); + cb(qrotated, "qrotated", il); + + struct ggml_tensor * krotated = ggml_rope_custom( + ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale + ); + cb(krotated, "krotated", il); + + // ggml currently only supports concatenation on dim=2 + // so we need to permute qrot, qpass, concat, then permute back. + qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3)); + cb(qrotated, "qrotated", il); + + krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3)); + cb(krotated, "krotated", il); + + qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3)); + cb(qpass, "qpass", il); + + kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3)); + cb(kpass, "kpass", il); + + struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3)); + cb(Q, "Q", il); + + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3)); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_view_3d( + ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2 + ); + cb(Vcur, "Vcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + // TODO: not tested, could be broken + cur = llm_build_kqv(ctx0, Q, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_refact() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + cb(Kcur, "Kcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur", il); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_bloom() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + inpL = llm_build_norm(ctx0, inpL, hparams, + model.tok_norm, + model.tok_norm_b, + LLM_NORM, cb, -1); + cb(inpL, "inp_norm", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il); + cb(cur, "kqv_out", il); + } + + // Add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // FF + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_mpt() { + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * attn_norm; + + attn_norm = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, - LLM_NORM, norm_eps, cb, il); - cb(cur, "ffn_norm", il); + LLM_NORM, cb, il); + cb(attn_norm, "attn_norm", il); - cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - NULL, NULL, - model.layers[il].ffn_down, NULL, - LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); - cb(cur, "ffn_out", il); + // self-attention + { + cur = attn_norm; + + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + if (hparams.f_clamp_kqv > 0.0f) { + cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(cur, "wqkv_clamped", il); + } + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self, + model.layers[il].wo, NULL, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il); + cb(cur, "kqv_out", il); + } + + // Add the input + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // feed forward + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + NULL, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + NULL, NULL, + model.layers[il].ffn_down, NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); + cur = inpL; - // input for next layer - inpL = cur; + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, + NULL, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; } - - cur = inpL; - - cur = llm_build_norm(ctx0, cur, - model.output_norm, - NULL, - LLM_NORM, norm_eps, cb, -1); - cb(cur, "result_norm", -1); - - cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); - - ggml_build_forward_expand(gf, cur); - - ggml_free(ctx0); - - return gf; -} +}; // // tensor offloading helpers @@ -5122,43 +4862,49 @@ static struct ggml_cgraph * llama_build_graph( struct ggml_cgraph * result = NULL; + struct llm_build_context llm(lctx, batch, cb, worst_case); + + llm.init(); + switch (model.arch) { case LLM_ARCH_LLAMA: { - result = llm_build_llama(lctx, batch, cb, worst_case); + result = llm.build_llama(); } break; case LLM_ARCH_BAICHUAN: { - result = llm_build_baichaun(lctx, batch, cb, worst_case); + result = llm.build_baichuan(); } break; case LLM_ARCH_FALCON: { - result = llm_build_falcon(lctx, batch, cb, worst_case); + result = llm.build_falcon(); } break; case LLM_ARCH_STARCODER: { - result = llm_build_starcoder(lctx, batch, cb, worst_case); + result = llm.build_starcoder(); } break; case LLM_ARCH_PERSIMMON: { - result = llm_build_persimmon(lctx, batch, cb, worst_case); + result = llm.build_persimmon(); } break; case LLM_ARCH_REFACT: { - result = llm_build_refact(lctx, batch, cb, worst_case); + result = llm.build_refact(); } break; case LLM_ARCH_BLOOM: { - result = llm_build_bloom(lctx, batch, cb, worst_case); + result = llm.build_bloom(); } break; case LLM_ARCH_MPT: { - result = llm_build_mpt(lctx, batch, cb, worst_case); + result = llm.build_mpt(); } break; default: GGML_ASSERT(false); } + llm.free(); + if (worst_case) { int n_non_view_total = 0;