Add interactive mode (#61)

* Initial work on interactive mode.

* Improve interactive mode. Make rev. prompt optional.

* Update README to explain interactive mode.

* Fix OS X build
This commit is contained in:
Matvey Soloviev 2023-03-12 22:13:28 +01:00 committed by GitHub
parent 9661954835
commit 96ea727f47
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 170 additions and 10 deletions

View file

@ -183,6 +183,29 @@ The number of files generated for each model is as follows:
When running the larger models, make sure you have enough disk space to store all the intermediate files. When running the larger models, make sure you have enough disk space to store all the intermediate files.
### Interactive mode
If you want a more ChatGPT-like experience, you can run in interactive mode by passing `-i` as a parameter.
In this mode, you can always interrupt generation by pressing Ctrl+C and enter one or more lines of text which will be converted into tokens and appended to the current context. You can also specify a *reverse prompt* with the parameter `-r "reverse prompt string"`. This will result in user input being prompted whenever the exact tokens of the reverse prompt string are encountered in the generation. A typical use is to use a prompt which makes LLaMa emulate a chat between multiple users, say Alice and Bob, and pass `-r "Alice:"`.
Here is an example few-shot interaction, invoked with the command
```
./main -m ./models/13B/ggml-model-q4_0.bin -t 8 --repeat_penalty 1.2 --temp 0.9 --top_p 0.9 -n 256 \
--color -i -r "User:" \
-p \
"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Hello, Bob.
Bob: Hello. How may I help you today?
User: Please tell me the largest city in Europe.
Bob: Sure. The largest city in Europe is London, the capital of the United Kingdom.
User:"
```
Note the use of `--color` to distinguish between user input and generated text.
![image](https://user-images.githubusercontent.com/401380/224572787-d418782f-47b2-49c4-a04e-65bfa7ad4ec0.png)
## Limitations ## Limitations
- Not sure if my tokenizer is correct. There are a few places where we might have a mistake: - Not sure if my tokenizer is correct. There are a few places where we might have a mistake:

137
main.cpp
View file

@ -11,6 +11,18 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <signal.h>
#include <unistd.h>
#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_GREEN "\x1b[32m"
#define ANSI_COLOR_YELLOW "\x1b[33m"
#define ANSI_COLOR_BLUE "\x1b[34m"
#define ANSI_COLOR_MAGENTA "\x1b[35m"
#define ANSI_COLOR_CYAN "\x1b[36m"
#define ANSI_COLOR_RESET "\x1b[0m"
#define ANSI_BOLD "\x1b[1m"
// determine number of model parts based on the dimension // determine number of model parts based on the dimension
static const std::map<int, int> LLAMA_N_PARTS = { static const std::map<int, int> LLAMA_N_PARTS = {
{ 4096, 1 }, { 4096, 1 },
@ -733,6 +745,18 @@ bool llama_eval(
return true; return true;
} }
static bool is_interacting = false;
void sigint_handler(int signo) {
if (signo == SIGINT) {
if (!is_interacting) {
is_interacting=true;
} else {
_exit(130);
}
}
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
ggml_time_init(); ggml_time_init();
const int64_t t_main_start_us = ggml_time_us(); const int64_t t_main_start_us = ggml_time_us();
@ -787,6 +811,9 @@ int main(int argc, char ** argv) {
params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
// tokenize the reverse prompt
std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
printf("\n"); printf("\n");
printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
@ -794,6 +821,24 @@ int main(int argc, char ** argv) {
printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str()); printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
} }
printf("\n"); printf("\n");
if (params.interactive) {
struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
printf("%s: interactive mode on.\n", __func__);
if(antiprompt_inp.size()) {
printf("%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
printf("%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
printf("%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
}
printf("\n");
}
}
printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
printf("\n\n"); printf("\n\n");
@ -807,7 +852,28 @@ int main(int argc, char ** argv) {
std::vector<gpt_vocab::id> last_n_tokens(last_n_size); std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
if (params.interactive) {
printf("== Running in interactive mode. ==\n"
" - Press Ctrl+C to interject at any time.\n"
" - Press Return to return control to LLaMa.\n"
" - If you want to submit another line, end your input in '\\'.\n");
}
int remaining_tokens = params.n_predict;
int input_consumed = 0;
bool input_noecho = false;
// prompt user immediately after the starting prompt has been loaded
if (params.interactive_start) {
is_interacting = true;
}
if (params.use_color) {
printf(ANSI_COLOR_YELLOW);
}
while (remaining_tokens > 0) {
// predict // predict
if (embd.size() > 0) { if (embd.size() > 0) {
const int64_t t_start_us = ggml_time_us(); const int64_t t_start_us = ggml_time_us();
@ -823,8 +889,8 @@ int main(int argc, char ** argv) {
n_past += embd.size(); n_past += embd.size();
embd.clear(); embd.clear();
if (i >= embd_inp.size()) { if (embd_inp.size() <= input_consumed) {
// sample next token // out of input, sample next token
const float top_k = params.top_k; const float top_k = params.top_k;
const float top_p = params.top_p; const float top_p = params.top_p;
const float temp = params.temp; const float temp = params.temp;
@ -847,24 +913,74 @@ int main(int argc, char ** argv) {
// add it to the context // add it to the context
embd.push_back(id); embd.push_back(id);
// echo this to console
input_noecho = false;
// decrement remaining sampling budget
--remaining_tokens;
} else { } else {
// if here, it means we are still processing the input prompt // if here, it means we are still processing the input prompt
for (int k = i; k < embd_inp.size(); k++) { while (embd_inp.size() > input_consumed) {
embd.push_back(embd_inp[k]); embd.push_back(embd_inp[input_consumed]);
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[k]); last_n_tokens.push_back(embd_inp[input_consumed]);
++input_consumed;
if (embd.size() > params.n_batch) { if (embd.size() > params.n_batch) {
break; break;
} }
} }
i += embd.size() - 1;
if (params.use_color && embd_inp.size() <= input_consumed) {
printf(ANSI_COLOR_RESET);
}
} }
// display text // display text
for (auto id : embd) { if (!input_noecho) {
printf("%s", vocab.id_to_token[id].c_str()); for (auto id : embd) {
printf("%s", vocab.id_to_token[id].c_str());
}
fflush(stdout);
}
// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && embd_inp.size() <= input_consumed) {
// check for reverse prompt
if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
// reverse prompt found
is_interacting = true;
}
if (is_interacting) {
// currently being interactive
bool another_line=true;
while (another_line) {
char buf[256] = {0};
int n_read;
if(params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
scanf("%255[^\n]%n%*c", buf, &n_read);
if(params.use_color) printf(ANSI_COLOR_RESET);
if (n_read > 0 && buf[n_read-1]=='\\') {
another_line = true;
buf[n_read-1] = '\n';
buf[n_read] = 0;
} else {
another_line = false;
buf[n_read] = '\n';
buf[n_read+1] = 0;
}
std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
input_noecho = true; // do not echo this again
}
is_interacting = false;
}
} }
fflush(stdout);
// end of text token // end of text token
if (embd.back() == 2) { if (embd.back() == 2) {
@ -873,6 +989,7 @@ int main(int argc, char ** argv) {
} }
} }
// report timing // report timing
{ {
const int64_t t_main_end_us = ggml_time_us(); const int64_t t_main_end_us = ggml_time_us();

View file

@ -49,6 +49,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.n_batch = std::stoi(argv[++i]); params.n_batch = std::stoi(argv[++i]);
} else if (arg == "-m" || arg == "--model") { } else if (arg == "-m" || arg == "--model") {
params.model = argv[++i]; params.model = argv[++i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "--interactive-start") {
params.interactive = true;
params.interactive_start = true;
} else if (arg == "--color") {
params.use_color = true;
} else if (arg == "-r" || arg == "--reverse-prompt") {
params.antiprompt = argv[++i];
} else if (arg == "-h" || arg == "--help") { } else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params); gpt_print_usage(argc, argv, params);
exit(0); exit(0);
@ -67,6 +76,11 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n"); fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -i, --interactive run in interactive mode\n");
fprintf(stderr, " --interactive-start run in interactive mode and poll user input at startup\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");

View file

@ -28,6 +28,12 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt; std::string prompt;
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode
bool interactive_start = false; // reverse prompt immediately
std::string antiprompt = ""; // string upon seeing which more user input is prompted
}; };
bool gpt_params_parse(int argc, char ** argv, gpt_params & params); bool gpt_params_parse(int argc, char ** argv, gpt_params & params);