diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp
index 36f9bf4..7960ab7 100644
--- a/examples/talk-llama/talk-llama.cpp
+++ b/examples/talk-llama/talk-llama.cpp
@@ -333,27 +333,10 @@ int main(int argc, char ** argv) {
     prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
 
-    // evaluate the initial prompt
-
-    auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
-
-    printf("\n");
-    printf("%s : initializing - please wait ...\n", __func__);
-
-    if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
-        fprintf(stderr, "%s : failed to eval\n", __func__);
-        return 1;
-    }
-
-    if (params.verbose_prompt) {
-        fprintf(stdout, "\n");
-        fprintf(stdout, "%s", prompt_llama.c_str());
-        fflush(stdout);
-    }
-
     // init session
     std::string path_session = params.path_session;
     std::vector<llama_token> session_tokens;
+    auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
 
     if (!path_session.empty()) {
         fprintf(stderr, "%s: attempting to load saved session from %s\n", __func__, path_session.c_str());
 
@@ -370,6 +353,9 @@ int main(int argc, char ** argv) {
                 return 1;
             }
 
             session_tokens.resize(n_token_count_out);
+            for (size_t i = 0; i < session_tokens.size(); i++) {
+                embd_inp[i] = session_tokens[i];
+            }
             fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
         } else {
@@ -377,6 +363,22 @@ int main(int argc, char ** argv) {
         }
     }
 
+    // evaluate the initial prompt
+
+    printf("\n");
+    printf("%s : initializing - please wait ...\n", __func__);
+
+    if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
+        fprintf(stderr, "%s : failed to eval\n", __func__);
+        return 1;
+    }
+
+    if (params.verbose_prompt) {
+        fprintf(stdout, "\n");
+        fprintf(stdout, "%s", prompt_llama.c_str());
+        fflush(stdout);
+    }
+
     // debug message about similarity of saved session, if applicable
     size_t n_matching_session_tokens = 0;
     if (session_tokens.size()) {
@@ -417,7 +419,7 @@ int main(int argc, char ** argv) {
 
     int n_past = n_keep;
     int n_prev = 64; // TODO arg
-    int n_session_consumed = 0;
+    int n_session_consumed = !path_session.empty() && session_tokens.size() > 0 ? session_tokens.size() : 0;
 
     std::vector<llama_token> embd;
 
@@ -494,6 +496,11 @@ int main(int argc, char ** argv) {
 
             embd = ::llama_tokenize(ctx_llama, text_heard, false);
 
+            // Append the new input tokens to the session_tokens vector
+            if (!path_session.empty()) {
+                session_tokens.insert(session_tokens.end(), tokens.begin(), tokens.end());
+            }
+
             // text inference
             bool done = false;
             std::string text_to_speak;
@@ -539,20 +546,21 @@ int main(int argc, char ** argv) {
                        }
                    }
 
+                   if (embd.size() > 0 && !path_session.empty()) {
+                       session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
+                       n_session_consumed = session_tokens.size();
+                   }
+
                    if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
                        fprintf(stderr, "%s : failed to eval\n", __func__);
                        return 1;
                    }
                }
 
-               //printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
               embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
               n_past += embd.size();
 
-              if (embd.size() > 0 && !path_session.empty()) {
-                  session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
-                  n_session_consumed = session_tokens.size();
-              }
+
               embd.clear();
 
               if (done) break;
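For reviewers, here is a minimal standalone sketch of the session-reuse idea the hunks above are wiring up, written against the same pre-GGUF llama.cpp C API that talk-llama.cpp uses (`llama_load_session_file`, `llama_eval`). The helper name `eval_prompt_with_session`, the capacity padding, and the isolation into a single function are illustrative assumptions, not code from this patch; talk-llama interleaves the same bookkeeping with its chat loop.

```cpp
// Sketch only: prompt-prefix reuse with a saved llama session.
// Assumes the pre-GGUF llama.cpp C API (llama_load_session_file, llama_eval);
// the function name and sizing choices are hypothetical, not from the patch.
#include <cstdio>
#include <string>
#include <vector>

#include "llama.h"

// Load a saved session (if any) and evaluate only the part of the prompt that
// the restored KV cache does not already cover. Returns n_past, or -1 on error.
static int eval_prompt_with_session(
        llama_context * ctx,
        const std::string & path_session,
        const std::vector<llama_token> & embd_inp,   // tokenized initial prompt
        std::vector<llama_token> & session_tokens,   // mirrors tokens fed to llama_eval
        int n_threads) {
    size_t n_matching = 0;

    if (!path_session.empty()) {
        session_tokens.resize(embd_inp.size() + 1024);  // capacity for the loader (assumed bound)
        size_t n_token_count_out = 0;

        if (llama_load_session_file(ctx, path_session.c_str(),
                                    session_tokens.data(), session_tokens.capacity(),
                                    &n_token_count_out)) {
            session_tokens.resize(n_token_count_out);

            // count how many leading prompt tokens match the saved session -
            // those are already in the restored KV cache and need no re-eval
            while (n_matching < session_tokens.size() &&
                   n_matching < embd_inp.size() &&
                   session_tokens[n_matching] == embd_inp[n_matching]) {
                n_matching++;
            }
        } else {
            session_tokens.clear();  // no usable session - evaluate the whole prompt
        }
    }

    // evaluate only the unmatched tail of the prompt
    if (n_matching < embd_inp.size()) {
        if (llama_eval(ctx, embd_inp.data() + n_matching,
                       (int) (embd_inp.size() - n_matching), (int) n_matching, n_threads)) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return -1;
        }
    }

    // keep session_tokens in sync with what has actually been evaluated, so a
    // later llama_save_session_file() call matches the KV cache
    session_tokens.resize(n_matching);
    session_tokens.insert(session_tokens.end(), embd_inp.begin() + n_matching, embd_inp.end());

    return (int) embd_inp.size();
}
```

The hunks above enforce the same invariant inside talk-llama's loop: `session_tokens` is seeded from the loaded session at startup and extended right before each `llama_eval` call, so `n_session_consumed` and any saved session file stay consistent with the KV cache.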