From a47e812a5454721e5c6e93062c520a3d6f8303b2 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Wed, 29 Mar 2023 16:01:14 -0400 Subject: [PATCH] talk-llama : add alpaca support (#668) --- examples/talk-llama/prompts/talk-alpaca.txt | 23 ++++++++++++++ examples/talk-llama/talk-llama.cpp | 33 +++++++++++++++++---- 2 files changed, 50 insertions(+), 6 deletions(-) create mode 100644 examples/talk-llama/prompts/talk-alpaca.txt diff --git a/examples/talk-llama/prompts/talk-alpaca.txt b/examples/talk-llama/prompts/talk-alpaca.txt new file mode 100644 index 0000000..79b9610 --- /dev/null +++ b/examples/talk-llama/prompts/talk-alpaca.txt @@ -0,0 +1,23 @@ +Below is an instruction that describes a task. Write a response that appropriately completes the request. + +### Instruction: + +Write a text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}. +{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision. +There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other. +The transcript only includes text, it does not include markup like HTML and Markdown. +{1} responds with short and concise answers. + +### Response: + +{0}{4} Hello, {1}! +{1}{4} Hello {0}! How may I help you today? +{0}{4} What time is it? +{1}{4} It is {2} o'clock. +{0}{4} What year is it? +{1}{4} We are in {3}. +{0}{4} What is a cat? +{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae. +{0}{4} Name a color. 
+{1}{4} Blue +{0}{4} diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index c7690f1..af5309c 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -33,6 +33,8 @@ struct whisper_params { int32_t max_tokens = 32; int32_t audio_ctx = 0; + int32_t n_parts_llama = -1; + float vad_thold = 0.6f; float freq_thold = 100.0f; @@ -41,12 +43,14 @@ struct whisper_params { bool print_special = false; bool print_energy = false; bool no_timestamps = true; + bool verbose_prompt = false; std::string person = "Georgi"; std::string language = "en"; std::string model_wsp = "models/ggml-base.en.bin"; std::string model_llama = "models/ggml-llama-7B.bin"; std::string speak = "./examples/talk/speak.sh"; + std::string prompt = ""; std::string fname_out; }; @@ -67,15 +71,24 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } + else if (arg == "--n-parts-llama") { params.n_parts_llama = std::stoi(argv[++i]); } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } + else if (arg == "--verbose-prompt") { params.verbose_prompt = true; } else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; } else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; } else if (arg 
== "-s" || arg == "--speak") { params.speak = argv[++i]; } + else if (arg == "--prompt-file") { + std::ifstream file(argv[++i]); + std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt)); + if (params.prompt.back() == '\n') { + params.prompt.pop_back(); + } + } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); @@ -108,7 +121,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str()); fprintf(stderr, " -mg FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str()); + fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama); fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str()); + fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", ""); + fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false"); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, "\n"); } @@ -183,8 +199,7 @@ std::string transcribe( const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)"; -// need to have leading ' ' -const std::string k_prompt_llama = R"( Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}. +const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}. {1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision. There are no annotations like (30 seconds passed...) 
or (to himself), just what {0} and {1} say aloud to each other. The transcript only includes text, it does not include markup like HTML and Markdown. @@ -227,6 +242,7 @@ int main(int argc, char ** argv) { lparams.n_ctx = 512; lparams.seed = 1; lparams.f16_kv = true; + lparams.n_parts = params.n_parts_llama; struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams); @@ -278,7 +294,10 @@ int main(int argc, char ** argv) { const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name); // construct the initial prompt for LLaMA inference - std::string prompt_llama = k_prompt_llama; + std::string prompt_llama = params.prompt.empty() ? k_prompt_llama : params.prompt; + + // need to have leading ' ' + prompt_llama.insert(0, 1, ' '); prompt_llama = ::replace(prompt_llama, "{0}", params.person); prompt_llama = ::replace(prompt_llama, "{1}", bot_name); @@ -323,9 +342,11 @@ int main(int argc, char ** argv) { return 1; } - //fprintf(stdout, "\n"); - //fprintf(stdout, "%s", prompt_llama.c_str()); - //fflush(stdout); + if (params.verbose_prompt) { + fprintf(stdout, "\n"); + fprintf(stdout, "%s", prompt_llama.c_str()); + fflush(stdout); + } printf("%s : done! start speaking in the microphone\n", __func__); printf("\n");