From 0fd6c1f015f6cccf3b527f7dbd8386a434728126 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 Mar 2024 10:12:29 +0200 Subject: [PATCH] embedding : print cosine similarity (#899) --- common/common.cpp | 13 +++++++++++++ common/common.h | 1 + examples/embedding/embedding.cpp | 21 ++++++++++++++++----- examples/gritlm/gritlm.cpp | 26 ++++++-------------------- 4 files changed, 36 insertions(+), 25 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 73b1b61ba..58fbd05aa 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1877,3 +1877,16 @@ void llama_embd_normalize(const float * inp, float * out, int n) { } } +float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n){ + double sum = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + + for (int i = 0; i < n; i++) { + sum += embd1[i] * embd2[i]; + sum1 += embd1[i] * embd1[i]; + sum2 += embd2[i] * embd2[i]; + } + + return sum / (sqrt(sum1) * sqrt(sum2)); +} diff --git a/common/common.h b/common/common.h index 0f178b9eb..d250eef8b 100644 --- a/common/common.h +++ b/common/common.h @@ -268,3 +268,4 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40 void llama_embd_normalize(const float * inp, float * out, int n); +float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 49302a199..f390c4061 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -168,14 +168,25 @@ int main(int argc, char ** argv) { batch_decode(ctx, batch, out, s, n_embd); // print first 3 embeddings + fprintf(stdout, "\n"); for (int j = 0; j < std::min(3, n_prompts); j++) { - fprintf(stderr, "embedding %d: ", j); - for (int i = 0; i < n_embd; i++) { - fprintf(stderr, "%f ", emb[j * n_embd + i]); + fprintf(stdout, "embedding %d: ", j); + for (int i = 0; i < std::min(16, n_embd); i++) { + fprintf(stdout, "%f ", emb[j * n_embd + i]); } - fprintf(stderr, "\n\n"); + fprintf(stdout, "\n"); + } + + // print cosine similarity matrix + fprintf(stdout, "\n"); + printf("cosine similarity matrix:\n\n"); + for (int i = 0; i < n_prompts; i++) { + for (int j = 0; j < n_prompts; j++) { + float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + fprintf(stdout, "%6.2f ", sim); + } + fprintf(stdout, "\n"); } - fprintf(stderr, "\n"); // clean up llama_print_timings(ctx); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 3d4b085d6..52fd719b3 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -6,22 +6,6 @@ // #define GRIT_DEBUG -static float dot_product(const std::vector & v1, const std::vector & v2) { - float dot = 0.0f; - for (uint64_t i = 0; i < v1.size(); ++i) { - dot += v1[i] * v2[i]; - } - return dot; -} - -static float norm(const std::vector & v) { - return std::sqrt(dot_product(v, v)); -} - -static float cosine_similarity(const std::vector & v1, const std::vector & v2) { - return dot_product(v1, v2) / (norm(v1) * norm(v2)); -} - static std::vector> encode(llama_context * ctx, const std::vector & sentences, const std::string & instruction) { std::vector> result; @@ -203,10 +187,12 @@ int main(int argc, char * argv[]) { const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); - const float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]); - const float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]); - const float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]); - const float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]); + const int n_embd = llama_n_embd(mdl); + + const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); + const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); + const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd); + const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd); std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0); std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);