From 41654efea879bbdf4fd794e13335929d4cf0eb90 Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Mon, 8 May 2023 19:45:48 -0700 Subject: [PATCH] Interface improvements and `--multiline-input` (previously `--author-mode`) (#1040) * Interface improvements * Multiline input * Track character width * Works with all characters and control codes + Windows console fixes --- examples/common.cpp | 384 +++++++++++++++++++++++++++++++++++------ examples/common.h | 25 ++- examples/main/main.cpp | 60 +++---- 3 files changed, 374 insertions(+), 95 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 6af440272..23d69e7d5 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -14,20 +14,16 @@ #include #endif -#if defined (_WIN32) +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#define NOMINMAX +#include #include #include -#pragma comment(lib,"kernel32.lib") -extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); -extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int CodePage, unsigned long dwFlags, - const wchar_t * lpWideCharStr, int cchWideChar, - char * lpMultiByteStr, int cbMultiByte, - const char * lpDefaultChar, bool * lpUsedDefaultChar); -#define CP_UTF8 65001 +#else +#include +#include +#include #endif int32_t get_num_physical_cores() { @@ -269,6 +265,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.interactive_first = true; } else if (arg == "-ins" || arg == "--instruct") { params.instruct = true; + } else if (arg == "--multiline-input") { + params.multiline_input = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -359,6 +357,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -i, --interactive run in interactive mode\n"); fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); + fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); fprintf(stderr, " specified more than once for multiple prompts).\n"); @@ -479,54 +478,339 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { return lctx; } -/* Keep track of current color of output, and emit ANSI code if it changes. */ -void set_console_color(console_state & con_st, console_color_t color) { - if (con_st.use_color && con_st.color != color) { - switch(color) { - case CONSOLE_COLOR_DEFAULT: - printf(ANSI_COLOR_RESET); - break; - case CONSOLE_COLOR_PROMPT: - printf(ANSI_COLOR_YELLOW); - break; - case CONSOLE_COLOR_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); - break; - } - con_st.color = color; - } -} - -#if defined (_WIN32) -void win32_console_init(bool enable_color) { - unsigned long dwMode = 0; - void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) - if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { - hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) - if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { - hConOut = 0; +void console_init(console_state & con_st) { +#if defined(_WIN32) + // Windows-specific console initialization + DWORD dwMode = 0; + con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) { + con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE); + if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) { + con_st.hConsole = NULL; } } - if (hConOut) { + if (con_st.hConsole) { // Enable ANSI colors on Windows 10+ - if (enable_color && !(dwMode & 0x4)) { - SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) + if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); } // Set console output codepage to UTF8 SetConsoleOutputCP(CP_UTF8); } - void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) - if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { + HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { // Set console input codepage to UTF16 _setmode(_fileno(stdin), _O_WTEXT); + + // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) + dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); + SetConsoleMode(hConIn, dwMode); + } +#else + // POSIX-specific console initialization + struct termios new_termios; + tcgetattr(STDIN_FILENO, &con_st.prev_state); + new_termios = con_st.prev_state; + new_termios.c_lflag &= ~(ICANON | ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); + + con_st.tty = fopen("/dev/tty", "w+"); + if (con_st.tty != nullptr) { + con_st.out = con_st.tty; + } +#endif + setlocale(LC_ALL, ""); +} + +void console_cleanup(console_state & con_st) { + // Reset console color + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + +#if !defined(_WIN32) + if (con_st.tty != nullptr) { + con_st.out = stdout; + fclose(con_st.tty); + con_st.tty = nullptr; + } + // Restore the terminal settings on POSIX systems + tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); +#endif +} + +/* Keep track of current color of output, and emit ANSI code if it changes. */ +void console_set_color(console_state & con_st, console_color_t color) { + if (con_st.use_color && con_st.color != color) { + fflush(stdout); + switch(color) { + case CONSOLE_COLOR_DEFAULT: + fprintf(con_st.out, ANSI_COLOR_RESET); + break; + case CONSOLE_COLOR_PROMPT: + fprintf(con_st.out, ANSI_COLOR_YELLOW); + break; + case CONSOLE_COLOR_USER_INPUT: + fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + } + con_st.color = color; + fflush(con_st.out); } } -// Convert a wide Unicode string to an UTF8 string -void win32_utf8_encode(const std::wstring & wstr, std::string & str) { - int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), NULL, 0, NULL, NULL); - std::string strTo(size_needed, 0); - WideCharToMultiByte(CP_UTF8, 0, &wstr[0], (int)wstr.size(), &strTo[0], size_needed, NULL, NULL); - str = strTo; -} +char32_t getchar32() { + wchar_t wc = getwchar(); + if (static_cast(wc) == WEOF) { + return WEOF; + } + +#if WCHAR_MAX == 0xFFFF + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + wchar_t low_surrogate = getwchar(); + if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate + return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; + } + } + if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair + return 0xFFFD; // Return the replacement character U+FFFD + } #endif + + return static_cast(wc); +} + +void pop_cursor(console_state & con_st) { +#if defined(_WIN32) + if (con_st.hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo); + + COORD newCursorPosition = bufferInfo.dwCursorPosition; + if (newCursorPosition.X == 0) { + newCursorPosition.X = bufferInfo.dwSize.X - 1; + newCursorPosition.Y -= 1; + } else { + newCursorPosition.X -= 1; + } + + SetConsoleCursorPosition(con_st.hConsole, newCursorPosition); + return; + } +#endif + putc('\b', con_st.out); +} + +int estimateWidth(char32_t codepoint) { +#if defined(_WIN32) + return 1; +#else + return wcwidth(codepoint); +#endif +} + +int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) { + // go with the default + return expectedWidth; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + DWORD nNumberOfChars = length; + WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + + // Figure out our real position if we're in the last column + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + DWORD nNumberOfChars; + WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL); + GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); + } + + int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; + if (width < 0) { + width += newBufferInfo.dwSize.X; + } + return width; +#else + // we can trust expectedWidth if we've got one + if (expectedWidth >= 0 || con_st.tty == nullptr) { + fwrite(utf8_codepoint, length, 1, con_st.out); + return expectedWidth; + } + + fputs("\033[6n", con_st.tty); // Query cursor position + int x1, x2, y1, y2; + int results = 0; + results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1); + + fwrite(utf8_codepoint, length, 1, con_st.tty); + + fputs("\033[6n", con_st.tty); // Query cursor position + results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2); + + if (results != 4) { + return expectedWidth; + } + + int width = x2 - x1; + if (width < 0) { + // Calculate the width considering text wrapping + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + width += w.ws_col; + } + return width; +#endif +} + +void replace_last(console_state & con_st, char ch) { +#if defined(_WIN32) + pop_cursor(con_st); + put_codepoint(con_st, &ch, 1, 1); +#else + fprintf(con_st.out, "\b%c", ch); +#endif +} + +void append_utf8(char32_t ch, std::string & out) { + if (ch <= 0x7F) { + out.push_back(static_cast(ch)); + } else if (ch <= 0x7FF) { + out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0xFFFF) { + out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0x10FFFF) { + out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); + out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else { + // Invalid Unicode code point + } +} + +// Helper function to remove the last UTF-8 character from a string +void pop_back_utf8_char(std::string & line) { + if (line.empty()) { + return; + } + + size_t pos = line.length() - 1; + + // Find the start of the last UTF-8 character (checking up to 4 bytes back) + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { + if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character + } + line.erase(pos); +} + +bool console_readline(console_state & con_st, std::string & line) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + if (con_st.out != stdout) { + fflush(stdout); + } + + line.clear(); + std::vector widths; + bool is_special_char = false; + bool end_of_stream = false; + + char32_t input_char; + while (true) { + fflush(con_st.out); // Ensure all output is displayed before waiting for input + input_char = getchar32(); + + if (input_char == '\r' || input_char == '\n') { + break; + } + + if (input_char == WEOF || input_char == 0x04 /* Ctrl+D*/) { + end_of_stream = true; + break; + } + + if (is_special_char) { + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + replace_last(con_st, line.back()); + is_special_char = false; + } + + if (input_char == '\033') { // Escape sequence + char32_t code = getchar32(); + if (code == '[' || code == 0x1B) { + // Discard the rest of the escape sequence + while ((code = getchar32()) != WEOF) { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + break; + } + } + } + } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace + if (!widths.empty()) { + int count; + do { + count = widths.back(); + widths.pop_back(); + // Move cursor back, print space, and move cursor back again + for (int i = 0; i < count; i++) { + replace_last(con_st, ' '); + pop_cursor(con_st); + } + pop_back_utf8_char(line); + } while (count == 0 && !widths.empty()); + } + } else { + int offset = line.length(); + append_utf8(input_char, line); + int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); + if (width < 0) { + width = 0; + } + widths.push_back(width); + } + + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + console_set_color(con_st, CONSOLE_COLOR_PROMPT); + replace_last(con_st, line.back()); + is_special_char = true; + } + } + + bool has_more = con_st.multiline_input; + if (is_special_char) { + replace_last(con_st, ' '); + pop_cursor(con_st); + + char last = line.back(); + line.pop_back(); + if (last == '\\') { + line += '\n'; + fputc('\n', con_st.out); + has_more = !has_more; + } else { + // llama will just eat the single space, it won't act as a space + if (line.length() == 1 && line.back() == ' ') { + line.clear(); + pop_cursor(con_st); + } + has_more = false; + } + } else { + if (end_of_stream) { + has_more = false; + } else { + line += '\n'; + fputc('\n', con_st.out); + } + } + + fflush(con_st.out); + return has_more; +} diff --git a/examples/common.h b/examples/common.h index 842e1516f..43f1cc9ef 100644 --- a/examples/common.h +++ b/examples/common.h @@ -10,6 +10,11 @@ #include #include +#if !defined (_WIN32) +#include +#include +#endif + // // CLI argument parsing // @@ -56,6 +61,7 @@ struct gpt_params { bool embedding = false; // get only sentence embedding bool interactive_first = false; // wait for user input immediately + bool multiline_input = false; // reverse the usage of `\` bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token @@ -104,13 +110,20 @@ enum console_color_t { }; struct console_state { + bool multiline_input = false; bool use_color = false; console_color_t color = CONSOLE_COLOR_DEFAULT; + + FILE* out = stdout; +#if defined (_WIN32) + void* hConsole; +#else + FILE* tty = nullptr; + termios prev_state; +#endif }; -void set_console_color(console_state & con_st, console_color_t color); - -#if defined (_WIN32) -void win32_console_init(bool enable_color); -void win32_utf8_encode(const std::wstring & wstr, std::string & str); -#endif +void console_init(console_state & con_st); +void console_cleanup(console_state & con_st); +void console_set_color(console_state & con_st, console_color_t color); +bool console_readline(console_state & con_st, std::string & line); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 045093c72..6e1172a48 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -35,12 +35,12 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { is_interacting=true; } else { + console_cleanup(con_st); + printf("\n"); llama_print_timings(*g_ctx); _exit(130); } @@ -59,10 +59,9 @@ int main(int argc, char ** argv) { // save choice to use color for later // (note for later: this is a slightly awkward choice) con_st.use_color = params.use_color; - -#if defined (_WIN32) - win32_console_init(params.use_color); -#endif + con_st.multiline_input = params.multiline_input; + console_init(con_st); + atexit([]() { console_cleanup(con_st); }); if (params.perplexity) { printf("\n************\n"); @@ -275,12 +274,21 @@ int main(int argc, char ** argv) { std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); if (params.interactive) { + const char *control_message; + if (con_st.multiline_input) { + control_message = " - To return control to LLaMa, end your input with '\\'.\n" + " - To return control without starting a new line, end your input with '/'.\n"; + } else { + control_message = " - Press Return to return control to LLaMa.\n" + " - To return control without starting a new line, end your input with '/'.\n" + " - If you want to submit another line, end your input with '\\'.\n"; + } fprintf(stderr, "== Running in interactive mode. ==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif - " - Press Return to return control to LLaMa.\n" - " - If you want to submit another line, end your input in '\\'.\n\n"); + "%s\n", control_message); + is_interacting = params.interactive_first; } @@ -299,7 +307,7 @@ int main(int argc, char ** argv) { int n_session_consumed = 0; // the first thing we will do is to output the prompt, so set color accordingly - set_console_color(con_st, CONSOLE_COLOR_PROMPT); + console_set_color(con_st, CONSOLE_COLOR_PROMPT); std::vector embd; @@ -498,7 +506,7 @@ int main(int argc, char ** argv) { } // reset color to default if we there is no pending user input if (input_echo && (int)embd_inp.size() == n_consumed) { - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } // in interactive mode, and not currently processing queued inputs; @@ -518,17 +526,12 @@ int main(int argc, char ** argv) { if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { is_interacting = true; is_antiprompt = true; - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - fflush(stdout); break; } } } if (n_past > 0 && is_interacting) { - // potentially set color to indicate we are taking user input - set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); - if (params.instruct) { printf("\n> "); } @@ -542,31 +545,12 @@ int main(int argc, char ** argv) { std::string line; bool another_line = true; do { -#if defined(_WIN32) - std::wstring wline; - if (!std::getline(std::wcin, wline)) { - // input stream is bad or EOF received - return 0; - } - win32_utf8_encode(wline, line); -#else - if (!std::getline(std::cin, line)) { - // input stream is bad or EOF received - return 0; - } -#endif - if (!line.empty()) { - if (line.back() == '\\') { - line.pop_back(); // Remove the continue character - } else { - another_line = false; - } - buffer += line + '\n'; // Append the line to the result - } + another_line = console_readline(con_st, line); + buffer += line; } while (another_line); // done taking input, reset color - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back @@ -622,7 +606,5 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); - set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - return 0; }