From 881800d1f083c39431cef288347082be516d1c80 Mon Sep 17 00:00:00 2001 From: Seb C <47074056+Sebby37@users.noreply.github.com> Date: Tue, 21 Nov 2023 00:26:59 +1030 Subject: [PATCH] main : Add ChatML functionality to main example (#4046) Co-authored-by: Sebastian Cramond --- common/common.cpp | 3 +++ common/common.h | 1 + examples/infill/infill.cpp | 7 +++++++ examples/main/main.cpp | 36 +++++++++++++++++++++++++++++++----- 4 files changed, 42 insertions(+), 5 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 3f10b5d7f..eec704b99 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -491,6 +491,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.interactive_first = true; } else if (arg == "-ins" || arg == "--instruct") { params.instruct = true; + } else if (arg == "-cml" || arg == "--chatml") { + params.chatml = true; } else if (arg == "--infill") { params.infill = true; } else if (arg == "--multiline-input") { @@ -730,6 +732,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -i, --interactive run in interactive mode\n"); printf(" --interactive-first run in interactive mode and wait for input right away\n"); printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); + printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n"); printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); printf(" -r PROMPT, --reverse-prompt PROMPT\n"); printf(" halt generation at PROMPT, return control in interactive mode\n"); diff --git a/common/common.h b/common/common.h index cc048daab..88fa13fc0 100644 --- a/common/common.h +++ b/common/common.h @@ -102,6 +102,7 @@ struct gpt_params { bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode + bool chatml = false; // chatml mode (used for models trained on chatml syntax) bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 11f7410ed..4a7827876 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -146,6 +146,13 @@ int main(int argc, char ** argv) { return 0; } + if (params.chatml) { + printf("\n************\n"); + printf("%s: please use the 'main' tool for chatml mode\n", __func__); + printf("************\n\n"); + + return 0; + } if (!params.antiprompt.empty()) { printf("\n************\n"); printf("%s: please use the 'main' tool for antiprompt mode\n", __func__); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 99d219d65..31ec8cade 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -234,8 +234,11 @@ int main(int argc, char ** argv) { std::vector embd_inp; - if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { + if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) { LOG("tokenize the prompt\n"); + if (params.chatml) { + params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>"; + } embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); } else { LOG("use session tokens\n"); @@ -313,7 +316,7 @@ int main(int argc, char ** argv) { } // number of tokens to keep when resetting context - if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) { + if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) { params.n_keep = (int)embd_inp.size(); } @@ -324,11 +327,23 @@ int main(int argc, char ** argv) { LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str()); LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str()); + // chatml prefix & suffix + const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true); + const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true); + + LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str()); + LOG("cml_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_sfx).c_str()); + // in instruct mode, we inject a prefix and a suffix to each input by the user if (params.instruct) { params.interactive_first = true; params.antiprompt.push_back("### Instruction:\n\n"); } + // similar for chatml mode + else if (params.chatml) { + params.interactive_first = true; + params.antiprompt.push_back("<|im_start|>user\n"); + } // enable interactive mode if interactive start is specified if (params.interactive_first) { @@ -705,7 +720,7 @@ int main(int argc, char ** argv) { is_interacting = true; printf("\n"); - } else if (params.instruct) { + } else if (params.instruct || params.chatml) { is_interacting = true; } } @@ -713,7 +728,7 @@ int main(int argc, char ** argv) { if (n_past > 0 && is_interacting) { LOG("waiting for user input\n"); - if (params.instruct) { + if (params.instruct || params.chatml) { printf("\n> "); } @@ -760,6 +775,12 @@ int main(int argc, char ** argv) { n_consumed = embd_inp.size(); embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); } + // chatml mode: insert user chat prefix + if (params.chatml && !is_antiprompt) { + LOG("inserting chatml prefix\n"); + n_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end()); + } if (params.escape) { process_escapes(buffer); } @@ -778,6 +799,11 @@ int main(int argc, char ** argv) { LOG("inserting instruction suffix\n"); embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); } + // chatml mode: insert assistant chat suffix + if (params.chatml) { + LOG("inserting chatml suffix\n"); + embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end()); + } for (size_t i = original_size; i < embd_inp.size(); ++i) { const llama_token token = embd_inp[i]; @@ -803,7 +829,7 @@ int main(int argc, char ** argv) { } // end of text token - if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive)) { + if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) { LOG_TEE(" [end of text]\n"); break; }