diff --git a/common/common.cpp b/common/common.cpp index d7f650ef4..16ef4d7f7 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1852,3 +1852,18 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) { printf("\n=== Done dumping\n"); } + +void llama_embd_normalize(const float * inp, float * out, int n) { + double sum = 0.0; + for (int i = 0; i < n; i++) { + sum += inp[i] * inp[i]; + } + sum = sqrt(sum); + + const float norm = sum > 0.0 ? 1.0f / sum : 0.0f; + + for (int i = 0; i < n; i++) { + out[i] = inp[i] * norm; + } +} + diff --git a/common/common.h b/common/common.h index 977ce419f..f8d82b871 100644 --- a/common/common.h +++ b/common/common.h @@ -260,3 +260,10 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); + +// +// Embedding utils +// + +void llama_embd_normalize(const float * inp, float * out, int n); + diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index ff5883da6..a553ae1c3 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -23,17 +23,6 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke } } -static void normalize(const float * vec, float * out, int n) { - float norm = 0; - for (int i = 0; i < n; i++) { - norm += vec[i] * vec[i]; - } - norm = sqrt(norm); - for (int i = 0; i < n; i++) { - out[i] = vec[i] / norm; - } -} - static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(ctx); @@ -44,7 +33,6 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu fprintf(stderr, "%s : failed to decode\n", __func__); } - // normalize on copy for (int i = 0; i < batch.n_tokens; i++) { if (!batch.logits[i]) { continue; @@ -61,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } float * out = output + batch.seq_id[i][0] * n_embd; - normalize(embd, out, n_embd); + llama_embd_normalize(embd, out, n_embd); } } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8cff514f2..796f3499c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1327,6 +1327,8 @@ struct server_context { const int n_embd = llama_n_embd(model); + std::vector embd_res(n_embd, 0.0f); + for (int i = 0; i < batch.n_tokens; ++i) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { continue; @@ -1350,8 +1352,10 @@ struct server_context { continue; } + llama_embd_normalize(embd, embd_res.data(), n_embd); + res.data = json { - {"embedding", std::vector(embd, embd + n_embd)}, + {"embedding", embd_res}, }; } @@ -3354,6 +3358,8 @@ int main(int argc, char ** argv) { // get the result server_task_result result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); + + // append to the responses responses.push_back(result.data); }