From e72e4158debb04126a0fabedf0452a5551780ea0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 28 Jan 2024 19:44:10 +0200 Subject: [PATCH] talk-llama : sync llama.cpp --- examples/talk-llama/llama.cpp | 282 ++++++++++++++++++++++++++++++++-- examples/talk-llama/llama.h | 5 +- 2 files changed, 276 insertions(+), 11 deletions(-) diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index b03b67e..f7d054c 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -11,6 +11,10 @@ # include "ggml-cuda.h" #elif defined(GGML_USE_CLBLAST) # include "ggml-opencl.h" +#elif defined(GGML_USE_VULKAN) +# include "ggml-vulkan.h" +#elif defined(GGML_USE_SYCL) +# include "ggml-sycl.h" #endif #ifdef GGML_USE_METAL @@ -52,6 +56,7 @@ #include #include #include +#include #include #include #include @@ -196,6 +201,7 @@ enum llm_arch { LLM_ARCH_PHI2, LLM_ARCH_PLAMO, LLM_ARCH_CODESHELL, + LLM_ARCH_ORION, LLM_ARCH_UNKNOWN, }; @@ -217,6 +223,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, }; enum llm_kv { @@ -641,6 +648,25 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_ORION, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, @@ -1256,8 +1282,14 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer if (host_buffer) { buft = ggml_backend_cuda_host_buffer_type(); } +#elif defined(GGML_USE_SYCL) + buft = ggml_backend_sycl_host_buffer_type(); #elif defined(GGML_USE_CPU_HBM) buft = ggml_backend_cpu_hbm_buffer_type(); +#elif defined(GGML_USE_VULKAN) + if (host_buffer) { + buft = ggml_backend_vk_host_buffer_type(); + } #endif if (buft == nullptr) { @@ -1275,6 +1307,10 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) { buft = ggml_backend_metal_buffer_type(); #elif defined(GGML_USE_CUBLAS) buft = ggml_backend_cuda_buffer_type(gpu); +#elif defined(GGML_USE_VULKAN) + buft = ggml_backend_vk_buffer_type(); +#elif defined(GGML_USE_SYCL) + buft = ggml_backend_sycl_buffer_type(gpu); #elif defined(GGML_USE_CLBLAST) buft = ggml_backend_opencl_buffer_type(); #endif @@ -1332,6 +1368,7 @@ enum e_model { MODEL_7B, MODEL_8B, MODEL_13B, + MODEL_14B, MODEL_15B, MODEL_30B, MODEL_34B, @@ -2683,6 +2720,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; case MODEL_13B: return "13B"; + case MODEL_14B: return "14B"; case MODEL_15B: return "15B"; case MODEL_30B: return "30B"; case MODEL_34B: return "34B"; @@ -2950,7 +2988,15 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_ORION: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 40: model.type = e_model::MODEL_14B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -3933,6 +3979,38 @@ static bool llm_load_tensors( layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; + case LLM_ARCH_ORION: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; + + default: throw std::runtime_error("unknown architecture"); } @@ -4563,6 +4641,126 @@ struct llm_build_context { ctx0 = nullptr; } } + struct ggml_cgraph * build_orion() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); + cb(inpL, "inp_embd", -1); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); + cb(inp_pos, "inp_pos", -1); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0); + cb(KQ_mask, "KQ_mask", -1); + + // shift the entire K-cache if needed + if (do_rope_shift) { + llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb); + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + // if (model.layers[il].bq) { + // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + // cb(Qcur, "Qcur", il); + // } + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + // if (model.layers[il].bk) { + // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + // cb(Kcur, "Kcur", il); + // } + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + // if (model.layers[il].bv) { + // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + // cb(Vcur, "Vcur", il); + // } + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il); + cb(cur, "kqv_out", il); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -6520,6 +6718,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_codeshell(); } break; + case LLM_ARCH_ORION: + { + result = llm.build_orion(); + } break; default: GGML_ASSERT(false); } @@ -6652,7 +6854,7 @@ static int llama_decode_internal( } const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 1; - if (ggml_cpu_has_cublas() && fully_offloaded) { + if ((ggml_cpu_has_cublas() || ggml_cpu_has_vulkan()) && fully_offloaded) { n_threads = 1; } @@ -7946,6 +8148,11 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c } void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { + // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast + // if (k >= (int32_t)candidates->size) { + // return; + // } + const int64_t t_start_sample_us = ggml_time_us(); k = std::max(k, (int) min_keep); @@ -8054,21 +8261,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can return; } - llama_sample_softmax(ctx, candidates); - const int64_t t_start_sample_us = ggml_time_us(); - float scale = candidates->data[0].p; // scale by max prob - size_t i = 1; // first token always matches + bool min_p_applied = false; - for (; i < candidates->size; ++i) { - if (candidates->data[i].p < p * scale && i >= min_keep) { - break; // prob too small + // if the candidates aren't sorted, try the unsorted implementation first + if (!candidates->sorted) { + std::vector filtered_tokens; + + float max_logit = -FLT_MAX; + for (size_t i = 0; i < candidates->size; ++i) { + max_logit = std::max(max_logit, candidates->data[i].logit); + } + const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max + + for (size_t i = 0; i < candidates->size; ++i) { + if (candidates->data[i].logit >= min_logit) { + filtered_tokens.push_back(candidates->data[i]); + } + } + + // if we have enough values the operation was a success + if (filtered_tokens.size() >= min_keep) { + memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); + candidates->size = filtered_tokens.size(); + min_p_applied = true; } } - // Resize the output vector to keep only the matching tokens - candidates->size = i; + // if the candidates are sorted or the unsorted implementation failed, use this implementation + if (!min_p_applied) { + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } + + const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max + size_t i = 1; // first token always matches + + for (; i < candidates->size; ++i) { + if (candidates->data[i].logit < min_logit && i >= min_keep) { + break; // prob too small + } + } + + // Resize the output vector to keep only the matching tokens + candidates->size = i; + } if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; @@ -9997,6 +10239,26 @@ struct llama_context * llama_new_context_with_model( } } } +#elif defined(GGML_USE_VULKAN) + if (model->n_gpu_layers > 0) { + ggml_backend_t backend = ggml_backend_vk_init(); + if (backend == nullptr) { + LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__); + llama_free(ctx); + return nullptr; + } + ctx->backends.push_back(backend); + } +#elif defined(GGML_USE_SYCL) + if (model->n_gpu_layers > 0) { + ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu); + if (backend == nullptr) { + LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, model->main_gpu); + llama_free(ctx); + return nullptr; + } + ctx->backends.push_back(backend); + } #endif ctx->backend_cpu = ggml_backend_cpu_init(); if (ctx->backend_cpu == nullptr) { diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 7b3634a..3e33072 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -6,6 +6,9 @@ #ifdef GGML_USE_CUBLAS #include "ggml-cuda.h" #define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES +#elif defined(GGML_USE_SYCL) +#include "ggml-sycl.h" +#define LLAMA_MAX_DEVICES GGML_SYCL_MAX_DEVICES #else #define LLAMA_MAX_DEVICES 1 #endif // GGML_USE_CUBLAS @@ -46,7 +49,7 @@ #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN #define LLAMA_SESSION_VERSION 4 -#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL) // Defined when llama.cpp is compiled with support for offloading model layers to GPU. #define LLAMA_SUPPORTS_GPU_OFFLOAD #endif