From 03bf161eb6dea6400ee49c6dc6b69bdcfa9fd3fc Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Tue, 13 Feb 2024 06:06:58 -0600 Subject: [PATCH] llama : support batched embeddings (#5466) * batched embedding: pool outputs by sequence id. updated embedding example * bring back non-causal attention * embd : minor improvements * llama : minor --------- Co-authored-by: Georgi Gerganov --- convert-hf-to-gguf.py | 1 + examples/embedding/embedding.cpp | 146 +++++++++++++++++++++++-------- gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 + llama.cpp | 61 +++++++++---- llama.h | 5 ++ 6 files changed, 163 insertions(+), 54 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index cae1551a2..5adfdc143 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1648,6 +1648,7 @@ class BertModel(Model): self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) self.gguf_writer.add_causal_attention(False) + self.gguf_writer.add_pooling_layer(True) self.gguf_writer.add_file_type(self.ftype) def set_vocab(self): diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 27376c8f0..b4688cf51 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -7,6 +7,51 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +static std::vector split_lines(const std::string & s) { + std::string line; + std::vector lines; + std::stringstream ss(s); + while (std::getline(ss, line)) { + lines.push_back(line); + } + return lines; +} + +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id) { + for (size_t i = 0; i < tokens.size(); i++) { + llama_batch_add(batch, tokens[i], i, { seq_id }, false); + } +} + +static void normalize(float * vec, float * out, int n) { + float norm = 0; + for (int i = 0; i < n; i++) { + norm += vec[i] * vec[i]; + } + norm = sqrt(norm); + for (int i = 0; i < n; i++) { + out[i] = vec[i] / norm; + } +} + +static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { + // clear previous kv_cache values (irrelevant for embeddings) + llama_kv_cache_clear(ctx); + + // run model + fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); + if (llama_decode(ctx, batch) < 0) { + fprintf(stderr, "%s : failed to decode\n", __func__); + } + + // normalize on copy + for (int k = 0; k < n_seq; k++) { + float * emb = llama_get_embeddings_ith(ctx, k); + float * out = output + k * n_embd; + normalize(emb, out, n_embd); + } +} + int main(int argc, char ** argv) { gpt_params params; @@ -55,59 +100,84 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s\n", get_system_info(params).c_str()); } - int n_past = 0; + // split the prompt into lines + std::vector prompts = split_lines(params.prompt); - // tokenize the prompt - auto embd_inp = ::llama_tokenize(ctx, params.prompt, true); + // max batch size + const uint64_t n_batch = params.n_batch; + GGML_ASSERT(params.n_batch == params.n_ctx); + // tokenize the prompts and trim + std::vector> inputs; + for (const auto & prompt : prompts) { + auto inp = ::llama_tokenize(ctx, prompt, true); + if (inp.size() > n_batch) { + inp.resize(n_batch); + } + inputs.push_back(inp); + } + + // tokenization stats if (params.verbose_prompt) { - fprintf(stderr, "\n"); - fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); - fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); - for (int i = 0; i < (int) embd_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str()); + for (int i = 0; i < (int) inputs.size(); i++) { + fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str()); + fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size()); + for (int j = 0; j < (int) inputs[i].size(); j++) { + fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str()); + } + fprintf(stderr, "\n\n"); } - fprintf(stderr, "\n"); } - if (embd_inp.size() > (size_t)n_ctx) { - fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n", - __func__, embd_inp.size(), n_ctx); - return 1; - } - - while (!embd_inp.empty()) { - int n_tokens = std::min(params.n_batch, (int) embd_inp.size()); - if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; - } - n_past += n_tokens; - embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens); - } + // initialize batch + const int n_prompts = prompts.size(); + struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts); + // allocate output const int n_embd = llama_n_embd(model); - auto * embeddings = llama_get_embeddings(ctx); + std::vector embeddings(n_prompts * n_embd, 0); + float * emb = embeddings.data(); - // l2-normalize embeddings - float norm = 0; - for (int i = 0; i < n_embd; i++) { - norm += embeddings[i] * embeddings[i]; - } - norm = sqrt(norm); - for (int i = 0; i < n_embd; i++) { - embeddings[i] /= norm; + // break into batches + int p = 0; // number of prompts processed already + int s = 0; // number of prompts in current batch + for (int k = 0; k < n_prompts; k++) { + // clamp to n_batch tokens + auto & inp = inputs[k]; + const uint64_t n_toks = inp.size(); + + // encode if at capacity + if (batch.n_tokens + n_toks > n_batch) { + float * out = emb + p * n_embd; + batch_decode(ctx, batch, out, s, n_embd); + llama_batch_clear(batch); + p += s; + s = 0; + } + + // add to batch + batch_add_seq(batch, inp, s); + s += 1; } - for (int i = 0; i < n_embd; i++) { - printf("%f ", embeddings[i]); - } - printf("\n"); + // final batch + float * out = emb + p * n_embd; + batch_decode(ctx, batch, out, s, n_embd); + // print first 3 embeddings + for (int j = 0; j < std::min(3, n_prompts); j++) { + fprintf(stderr, "embedding %d: ", j); + for (int i = 0; i < n_embd; i++) { + fprintf(stderr, "%f ", emb[j * n_embd + i]); + } + fprintf(stderr, "\n\n"); + } + fprintf(stderr, "\n"); + + // clean up llama_print_timings(ctx); llama_free(ctx); llama_free_model(model); - llama_backend_free(); return 0; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a9c13dd38..644e1589c 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -40,6 +40,7 @@ class Keys: TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" EXPERT_COUNT = "{arch}.expert_count" EXPERT_USED_COUNT = "{arch}.expert_used_count" + POOLING_LAYER = "{arch}.pooling_layer" class Attention: HEAD_COUNT = "{arch}.attention.head_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 7af58a46c..d87bd8e88 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -360,6 +360,9 @@ class GGUFWriter: def add_causal_attention(self, value: bool) -> None: self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) + def add_pooling_layer(self, value: bool) -> None: + self.add_bool(Keys.LLM.POOLING_LAYER.format(arch=self.arch), value) + def add_rope_dimension_count(self, count: int) -> None: self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count) diff --git a/llama.cpp b/llama.cpp index 6dce392df..eb6c46f36 100644 --- a/llama.cpp +++ b/llama.cpp @@ -254,6 +254,7 @@ enum llm_kv { LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, LLM_KV_EXPERT_USED_COUNT, + LLM_KV_POOLING_LAYER, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -311,6 +312,7 @@ static std::map LLM_KV_NAMES = { { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_POOLING_LAYER, "%s.pooling_layer" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -1539,6 +1541,7 @@ struct llama_hparams { float f_max_alibi_bias; bool causal_attn = true; + bool pooling_layer = false; bool operator!=(const llama_hparams & other) const { @@ -1601,6 +1604,7 @@ struct llama_cparams { bool mul_mat_q; bool offload_kqv; + bool do_pooling; ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; @@ -1896,7 +1900,7 @@ struct llama_context { struct ggml_tensor * inp_pos; // I32 [n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] struct ggml_tensor * inp_K_shift; // I32 [n_ctx] - struct ggml_tensor * inp_sum; // F32 [1, n_batch] + struct ggml_tensor * inp_sum; // F32 [n_batch, n_batch] #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; @@ -3053,6 +3057,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); + ml.get_key(LLM_KV_POOLING_LAYER, hparams.pooling_layer); switch (hparams.n_layer) { case 3: @@ -4859,7 +4864,7 @@ struct llm_build_context { const int32_t n_orig_ctx; const bool do_rope_shift; - const bool causal_attn; + const bool do_pooling; const llm_build_cb & cb; @@ -4903,7 +4908,7 @@ struct llm_build_context { kv_head (worst_case ? n_ctx - n_tokens : kv_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), do_rope_shift (worst_case || kv_self.has_shift), - causal_attn (hparams.causal_attn), + do_pooling (hparams.pooling_layer && cparams.do_pooling), cb (cb), buf_compute_meta (lctx.buf_compute_meta) { // all initializations should be done in init() @@ -5752,17 +5757,18 @@ struct llm_build_context { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); struct ggml_tensor * cur; struct ggml_tensor * inpL; // get input vectors with right size + const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type); struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); - struct ggml_tensor * inp_sum = ggml_view_1d(ctx0, lctx.inp_sum, n_tokens, 0); + struct ggml_tensor * inp_sum = ggml_view_2d(ctx0, lctx.inp_sum, n_tokens, n_tokens, stride1, 0); // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); + // token types are hardcoded to zero ("Sentence A") struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); inpL = ggml_add(ctx0, inpL, type_row0); @@ -5832,9 +5838,11 @@ struct llm_build_context { // final output cur = inpL; - // pooling - cur = ggml_mul_mat(ctx0, inp_sum, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); - cb(cur, "result_embed", -1); + // pooling layer + if (do_pooling) { + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_sum); + } + cb(cur, "result_embd", -1); ggml_build_forward_expand(gf, cur); @@ -7367,7 +7375,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_kv; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || + (hparams.causal_attn && lctx.kv_self.cells[i].pos > pos)) { f = -INFINITY; } else { f = 0; @@ -7378,7 +7387,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - { assert(ggml_backend_buffer_is_host(lctx.inp_sum->buffer)); float * data = (float *) lctx.inp_sum->data; @@ -7399,6 +7407,20 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { data[i] = lctx.kv_self.cells[i].delta; } } + + if (hparams.pooling_layer && cparams.do_pooling) { + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_sum->buffer)); + float * data = (float *) lctx.inp_sum->data; + + memset(lctx.inp_sum->data, 0, batch.n_tokens * batch.n_tokens * ggml_element_size(lctx.inp_sum)); + + for (int i = 0; i < n_tokens; ++i) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + data[seq_id*n_tokens + i] = 1.0f; + } + } } // decode a batch of tokens by evaluating the transformer @@ -7510,7 +7532,7 @@ static int llama_decode_internal( embeddings = gf->nodes[gf->n_nodes - 3]; GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); } - } else if (strcmp(res->name, "result_embed") == 0) { + } else if (strcmp(res->name, "result_embd") == 0) { embeddings = res; res = nullptr; } else { @@ -7630,11 +7652,12 @@ static int llama_decode_internal( if (!lctx.embedding.empty()) { auto & embedding_out = lctx.embedding; - const int64_t embed_pos = res ? n_embd * (n_tokens-1) : 0; + const int64_t embd_pos = res ? n_embd * (n_tokens-1) : 0; + const int64_t embd_size = res ? n_embd : n_embd * n_tokens; - embedding_out.resize(n_embd); + embedding_out.resize(embd_size); ggml_backend_t embeddings_backend = ggml_backend_sched_get_node_backend(lctx.sched, embeddings); - ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embed_pos*sizeof(float), n_embd*sizeof(float)); + ggml_backend_tensor_get_async(embeddings_backend, embeddings, embedding_out.data(), embd_pos*sizeof(float), embd_size*sizeof(float)); ggml_backend_synchronize(embeddings_backend); } @@ -10950,6 +10973,7 @@ struct llama_context_params llama_context_default_params() { /*.logits_all =*/ false, /*.embedding =*/ false, /*.offload_kqv =*/ true, + /*.do_pooling =*/ true, }; return result; @@ -11105,6 +11129,7 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.mul_mat_q = params.mul_mat_q; cparams.offload_kqv = params.offload_kqv; + cparams.do_pooling = params.do_pooling; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; @@ -11252,7 +11277,7 @@ struct llama_context * llama_new_context_with_model( // resized during inference, reserve maximum ctx->logits.reserve(hparams.n_vocab*cparams.n_batch); - if (params.embedding){ + if (params.embedding) { ctx->embedding.resize(hparams.n_embd); } @@ -11270,7 +11295,7 @@ struct llama_context * llama_new_context_with_model( ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); - ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, 1, cparams.n_batch); + ctx->inp_sum = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch); ggml_set_name(ctx->inp_tokens, "inp_tokens"); ggml_set_name(ctx->inp_embd, "inp_embd"); @@ -12128,6 +12153,10 @@ float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embedding.data(); } +float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { + return ctx->embedding.data() + i*ctx->model.hparams.n_embd; +} + const char * llama_token_get_text(const struct llama_model * model, llama_token token) { return model->vocab.id_to_token[token].text.c_str(); } diff --git a/llama.h b/llama.h index 367e8f1a1..5ef78ec96 100644 --- a/llama.h +++ b/llama.h @@ -236,6 +236,7 @@ extern "C" { bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embedding; // embedding mode only bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU + bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer) }; // model quantization parameters @@ -628,6 +629,10 @@ extern "C" { // shape: [n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); + // Get the embeddings for the ith sequence + // llama_get_embeddings(ctx) + i*n_embd + LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); + // // Vocab //