From 3498588e0fb4daf040c4e3c698595cb0bfd345c0 Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Fri, 4 Aug 2023 08:20:12 -0700 Subject: [PATCH] Add --simple-io option for subprocesses and break out console.h and cpp (#1558) --- Makefile | 5 +- examples/CMakeLists.txt | 2 + examples/common.cpp | 377 +----------------------------- examples/common.h | 45 +--- examples/console.cpp | 494 ++++++++++++++++++++++++++++++++++++++++ examples/console.h | 19 ++ examples/main/main.cpp | 29 ++- 7 files changed, 536 insertions(+), 435 deletions(-) create mode 100644 examples/console.cpp create mode 100644 examples/console.h diff --git a/Makefile b/Makefile index a692a39ea..e0528aeee 100644 --- a/Makefile +++ b/Makefile @@ -340,6 +340,9 @@ llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h llama-ut common.o: examples/common.cpp examples/common.h $(CXX) $(CXXFLAGS) -c $< -o $@ +console.o: examples/console.cpp examples/console.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h $(CXX) $(CXXFLAGS) -c $< -o $@ @@ -353,7 +356,7 @@ clean: # Examples # -main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) +main: examples/main/main.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. 
====' diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 4b1f1cf44..a7b26776a 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -13,6 +13,8 @@ set(TARGET common) add_library(${TARGET} OBJECT common.h common.cpp + console.h + console.cpp grammar-parser.h grammar-parser.cpp ) diff --git a/examples/common.cpp b/examples/common.cpp index 3e7c3b696..21f4a0357 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -25,7 +25,6 @@ #else #include #include -#include #endif #if defined(_MSC_VER) @@ -329,6 +328,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.instruct = true; } else if (arg == "--multiline-input") { params.multiline_input = true; + } else if (arg == "--simple-io") { + params.simple_io = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -598,6 +599,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " --mtest compute maximum memory usage\n"); fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n"); fprintf(stdout, " --verbose-prompt print prompt before generation\n"); + fprintf(stdout, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); fprintf(stdout, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); fprintf(stdout, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); fprintf(stdout, " -m FNAME, --model FNAME\n"); @@ -690,376 +692,3 @@ std::tuple llama_init_from_gpt_par return std::make_tuple(model, lctx); } - -void console_init(console_state & con_st) { -#if defined(_WIN32) - // Windows-specific console initialization - DWORD dwMode = 0; - con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE); - if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) { - con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE); - if (con_st.hConsole != 
INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) { - con_st.hConsole = NULL; - } - } - if (con_st.hConsole) { - // Enable ANSI colors on Windows 10+ - if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { - SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); - } - // Set console output codepage to UTF8 - SetConsoleOutputCP(CP_UTF8); - } - HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); - if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { - // Set console input codepage to UTF16 - _setmode(_fileno(stdin), _O_WTEXT); - - // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) - dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); - SetConsoleMode(hConIn, dwMode); - } -#else - // POSIX-specific console initialization - struct termios new_termios; - tcgetattr(STDIN_FILENO, &con_st.prev_state); - new_termios = con_st.prev_state; - new_termios.c_lflag &= ~(ICANON | ECHO); - new_termios.c_cc[VMIN] = 1; - new_termios.c_cc[VTIME] = 0; - tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); - - con_st.tty = fopen("/dev/tty", "w+"); - if (con_st.tty != nullptr) { - con_st.out = con_st.tty; - } - - setlocale(LC_ALL, ""); -#endif -} - -void console_cleanup(console_state & con_st) { - // Reset console color - console_set_color(con_st, CONSOLE_COLOR_DEFAULT); - -#if !defined(_WIN32) - if (con_st.tty != nullptr) { - con_st.out = stdout; - fclose(con_st.tty); - con_st.tty = nullptr; - } - // Restore the terminal settings on POSIX systems - tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); -#endif -} - -/* Keep track of current color of output, and emit ANSI code if it changes. 
*/ -void console_set_color(console_state & con_st, console_color_t color) { - if (con_st.use_color && con_st.color != color) { - fflush(stdout); - switch(color) { - case CONSOLE_COLOR_DEFAULT: - fprintf(con_st.out, ANSI_COLOR_RESET); - break; - case CONSOLE_COLOR_PROMPT: - fprintf(con_st.out, ANSI_COLOR_YELLOW); - break; - case CONSOLE_COLOR_USER_INPUT: - fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); - break; - case CONSOLE_COLOR_ERROR: - fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_RED); - break; - } - con_st.color = color; - fflush(con_st.out); - } -} - -char32_t getchar32() { -#if defined(_WIN32) - HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE); - wchar_t high_surrogate = 0; - - while (true) { - INPUT_RECORD record; - DWORD count; - if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) { - return WEOF; - } - - if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) { - wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar; - if (wc == 0) { - continue; - } - - if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate - high_surrogate = wc; - continue; - } else if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate - if (high_surrogate != 0) { // Check if we have a high surrogate - return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000; - } - } - - high_surrogate = 0; // Reset the high surrogate - return static_cast(wc); - } - } -#else - wchar_t wc = getwchar(); - if (static_cast(wc) == WEOF) { - return WEOF; - } - -#if WCHAR_MAX == 0xFFFF - if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate - wchar_t low_surrogate = getwchar(); - if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate - return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; - } - } - if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair - return 0xFFFD; // Return the replacement character U+FFFD - } -#endif - - 
return static_cast(wc); -#endif -} - -void pop_cursor(console_state & con_st) { -#if defined(_WIN32) - if (con_st.hConsole != NULL) { - CONSOLE_SCREEN_BUFFER_INFO bufferInfo; - GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo); - - COORD newCursorPosition = bufferInfo.dwCursorPosition; - if (newCursorPosition.X == 0) { - newCursorPosition.X = bufferInfo.dwSize.X - 1; - newCursorPosition.Y -= 1; - } else { - newCursorPosition.X -= 1; - } - - SetConsoleCursorPosition(con_st.hConsole, newCursorPosition); - return; - } -#endif - putc('\b', con_st.out); -} - -int estimateWidth(char32_t codepoint) { -#if defined(_WIN32) - return 1; -#else - return wcwidth(codepoint); -#endif -} - -int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) { -#if defined(_WIN32) - CONSOLE_SCREEN_BUFFER_INFO bufferInfo; - if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) { - // go with the default - return expectedWidth; - } - COORD initialPosition = bufferInfo.dwCursorPosition; - DWORD nNumberOfChars = length; - WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); - - CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; - GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); - - // Figure out our real position if we're in the last column - if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { - DWORD nNumberOfChars; - WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL); - GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); - } - - int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; - if (width < 0) { - width += newBufferInfo.dwSize.X; - } - return width; -#else - // we can trust expectedWidth if we've got one - if (expectedWidth >= 0 || con_st.tty == nullptr) { - fwrite(utf8_codepoint, length, 1, con_st.out); - return expectedWidth; - } - - fputs("\033[6n", con_st.tty); // Query cursor position - int x1, x2, y1, y2; - int results = 
0; - results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1); - - fwrite(utf8_codepoint, length, 1, con_st.tty); - - fputs("\033[6n", con_st.tty); // Query cursor position - results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2); - - if (results != 4) { - return expectedWidth; - } - - int width = x2 - x1; - if (width < 0) { - // Calculate the width considering text wrapping - struct winsize w; - ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); - width += w.ws_col; - } - return width; -#endif -} - -void replace_last(console_state & con_st, char ch) { -#if defined(_WIN32) - pop_cursor(con_st); - put_codepoint(con_st, &ch, 1, 1); -#else - fprintf(con_st.out, "\b%c", ch); -#endif -} - -void append_utf8(char32_t ch, std::string & out) { - if (ch <= 0x7F) { - out.push_back(static_cast(ch)); - } else if (ch <= 0x7FF) { - out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); - out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else if (ch <= 0xFFFF) { - out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); - out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); - out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else if (ch <= 0x10FFFF) { - out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); - out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); - out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); - out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else { - // Invalid Unicode code point - } -} - -// Helper function to remove the last UTF-8 character from a string -void pop_back_utf8_char(std::string & line) { - if (line.empty()) { - return; - } - - size_t pos = line.length() - 1; - - // Find the start of the last UTF-8 character (checking up to 4 bytes back) - for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { - if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character - } - line.erase(pos); -} - -bool console_readline(console_state & con_st, std::string & line) { - console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); - if (con_st.out != stdout) { - 
fflush(stdout); - } - - line.clear(); - std::vector widths; - bool is_special_char = false; - bool end_of_stream = false; - - char32_t input_char; - while (true) { - fflush(con_st.out); // Ensure all output is displayed before waiting for input - input_char = getchar32(); - - if (input_char == '\r' || input_char == '\n') { - break; - } - - if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) { - end_of_stream = true; - break; - } - - if (is_special_char) { - console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); - replace_last(con_st, line.back()); - is_special_char = false; - } - - if (input_char == '\033') { // Escape sequence - char32_t code = getchar32(); - if (code == '[' || code == 0x1B) { - // Discard the rest of the escape sequence - while ((code = getchar32()) != (char32_t) WEOF) { - if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { - break; - } - } - } - } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace - if (!widths.empty()) { - int count; - do { - count = widths.back(); - widths.pop_back(); - // Move cursor back, print space, and move cursor back again - for (int i = 0; i < count; i++) { - replace_last(con_st, ' '); - pop_cursor(con_st); - } - pop_back_utf8_char(line); - } while (count == 0 && !widths.empty()); - } - } else { - int offset = line.length(); - append_utf8(input_char, line); - int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); - if (width < 0) { - width = 0; - } - widths.push_back(width); - } - - if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { - console_set_color(con_st, CONSOLE_COLOR_PROMPT); - replace_last(con_st, line.back()); - is_special_char = true; - } - } - - bool has_more = con_st.multiline_input; - if (is_special_char) { - replace_last(con_st, ' '); - pop_cursor(con_st); - - char last = line.back(); - line.pop_back(); - if (last == '\\') { - line += '\n'; - fputc('\n', con_st.out); - 
has_more = !has_more; - } else { - // llama will just eat the single space, it won't act as a space - if (line.length() == 1 && line.back() == ' ') { - line.clear(); - pop_cursor(con_st); - } - has_more = false; - } - } else { - if (end_of_stream) { - has_more = false; - } else { - line += '\n'; - fputc('\n', con_st.out); - } - } - - fflush(con_st.out); - return has_more; -} diff --git a/examples/common.h b/examples/common.h index 974484207..375bc0a3d 100644 --- a/examples/common.h +++ b/examples/common.h @@ -11,11 +11,6 @@ #include #include -#if !defined (_WIN32) -#include -#include -#endif - // // CLI argument parsing // @@ -85,6 +80,7 @@ struct gpt_params { bool embedding = false; // get only sentence embedding bool interactive_first = false; // wait for user input immediately bool multiline_input = false; // reverse the usage of `\` + bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool instruct = false; // instruction mode (used for Alpaca models) @@ -116,42 +112,3 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s std::tuple llama_init_from_gpt_params(const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); - -// -// Console utils -// - -#define ANSI_COLOR_RED "\x1b[31m" -#define ANSI_COLOR_GREEN "\x1b[32m" -#define ANSI_COLOR_YELLOW "\x1b[33m" -#define ANSI_COLOR_BLUE "\x1b[34m" -#define ANSI_COLOR_MAGENTA "\x1b[35m" -#define ANSI_COLOR_CYAN "\x1b[36m" -#define ANSI_COLOR_RESET "\x1b[0m" -#define ANSI_BOLD "\x1b[1m" - -enum console_color_t { - CONSOLE_COLOR_DEFAULT=0, - CONSOLE_COLOR_PROMPT, - CONSOLE_COLOR_USER_INPUT, - CONSOLE_COLOR_ERROR -}; - -struct console_state { - bool multiline_input = false; - bool use_color = false; - console_color_t color = CONSOLE_COLOR_DEFAULT; - - FILE* out = stdout; -#if defined (_WIN32) - void* hConsole; -#else - FILE* 
tty = nullptr; - termios prev_state; -#endif -}; - -void console_init(console_state & con_st); -void console_cleanup(console_state & con_st); -void console_set_color(console_state & con_st, console_color_t color); -bool console_readline(console_state & con_st, std::string & line); diff --git a/examples/console.cpp b/examples/console.cpp new file mode 100644 index 000000000..4c32f3b09 --- /dev/null +++ b/examples/console.cpp @@ -0,0 +1,494 @@ +#include "console.h" +#include +#include + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#define ANSI_COLOR_RED "\x1b[31m" +#define ANSI_COLOR_GREEN "\x1b[32m" +#define ANSI_COLOR_YELLOW "\x1b[33m" +#define ANSI_COLOR_BLUE "\x1b[34m" +#define ANSI_COLOR_MAGENTA "\x1b[35m" +#define ANSI_COLOR_CYAN "\x1b[36m" +#define ANSI_COLOR_RESET "\x1b[0m" +#define ANSI_BOLD "\x1b[1m" + +namespace console { + + // + // Console state + // + + static bool advanced_display = false; + static bool simple_io = true; + static display_t current_display = reset; + + static FILE* out = stdout; + +#if defined (_WIN32) + static void* hConsole; +#else + static FILE* tty = nullptr; + static termios initial_state; +#endif + + // + // Init and cleanup + // + + void init(bool use_simple_io, bool use_advanced_display) { + advanced_display = use_advanced_display; + simple_io = use_simple_io; +#if defined(_WIN32) + // Windows-specific console initialization + DWORD dwMode = 0; + hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) { + hConsole = GetStdHandle(STD_ERROR_HANDLE); + if (hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(hConsole, &dwMode))) { + hConsole = nullptr; + simple_io = true; + } + } + if (hConsole) { + // Enable ANSI colors on Windows 10+ + if (advanced_display && !(dwMode & 
ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); + } + // Set console output codepage to UTF8 + SetConsoleOutputCP(CP_UTF8); + } + HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { + // Set console input codepage to UTF16 + _setmode(_fileno(stdin), _O_WTEXT); + + if (!simple_io) { + // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) + dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); + } + if (!SetConsoleMode(hConIn, dwMode)) { + simple_io = true; + } + } +#else + // POSIX-specific console initialization + if (!simple_io) { + struct termios new_termios; + tcgetattr(STDIN_FILENO, &initial_state); + new_termios = initial_state; + new_termios.c_lflag &= ~(ICANON | ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); + + tty = fopen("/dev/tty", "w+"); + if (tty != nullptr) { + out = tty; + } + } + + setlocale(LC_ALL, ""); +#endif + } + + void cleanup() { + // Reset console display + set_display(reset); + +#if !defined(_WIN32) + // Restore settings on POSIX systems + if (!simple_io) { + if (tty != nullptr) { + out = stdout; + fclose(tty); + tty = nullptr; + } + tcsetattr(STDIN_FILENO, TCSANOW, &initial_state); + } +#endif + } + + // + // Display and IO + // + + // Keep track of current display and only emit ANSI code if it changes + void set_display(display_t display) { + if (advanced_display && current_display != display) { + fflush(stdout); + switch(display) { + case reset: + fprintf(out, ANSI_COLOR_RESET); + break; + case prompt: + fprintf(out, ANSI_COLOR_YELLOW); + break; + case user_input: + fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + case error: + fprintf(out, ANSI_BOLD ANSI_COLOR_RED); + } + current_display = display; + fflush(out); + } + } + + char32_t getchar32() { +#if defined(_WIN32) + HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE); 
+ wchar_t high_surrogate = 0; + + while (true) { + INPUT_RECORD record; + DWORD count; + if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) { + return WEOF; + } + + if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) { + wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar; + if (wc == 0) { + continue; + } + + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + high_surrogate = wc; + continue; + } + if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate + if (high_surrogate != 0) { // Check if we have a high surrogate + return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000; + } + } + + high_surrogate = 0; // Reset the high surrogate + return static_cast(wc); + } + } +#else + wchar_t wc = getwchar(); + if (static_cast(wc) == WEOF) { + return WEOF; + } + +#if WCHAR_MAX == 0xFFFF + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + wchar_t low_surrogate = getwchar(); + if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate + return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; + } + } + if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair + return 0xFFFD; // Return the replacement character U+FFFD + } +#endif + + return static_cast(wc); +#endif + } + + void pop_cursor() { +#if defined(_WIN32) + if (hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(hConsole, &bufferInfo); + + COORD newCursorPosition = bufferInfo.dwCursorPosition; + if (newCursorPosition.X == 0) { + newCursorPosition.X = bufferInfo.dwSize.X - 1; + newCursorPosition.Y -= 1; + } else { + newCursorPosition.X -= 1; + } + + SetConsoleCursorPosition(hConsole, newCursorPosition); + return; + } +#endif + putc('\b', out); + } + + int estimateWidth(char32_t codepoint) { +#if defined(_WIN32) + return 1; +#else + return wcwidth(codepoint); +#endif + } + + int 
put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) { + // go with the default + return expectedWidth; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + DWORD nNumberOfChars = length; + WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); + + // Figure out our real position if we're in the last column + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + DWORD nNumberOfChars; + WriteConsole(hConsole, &" \b", 2, &nNumberOfChars, NULL); + GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); + } + + int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; + if (width < 0) { + width += newBufferInfo.dwSize.X; + } + return width; +#else + // We can trust expectedWidth if we've got one + if (expectedWidth >= 0 || tty == nullptr) { + fwrite(utf8_codepoint, length, 1, out); + return expectedWidth; + } + + fputs("\033[6n", tty); // Query cursor position + int x1; + int y1; + int x2; + int y2; + int results = 0; + results = fscanf(tty, "\033[%d;%dR", &y1, &x1); + + fwrite(utf8_codepoint, length, 1, tty); + + fputs("\033[6n", tty); // Query cursor position + results += fscanf(tty, "\033[%d;%dR", &y2, &x2); + + if (results != 4) { + return expectedWidth; + } + + int width = x2 - x1; + if (width < 0) { + // Calculate the width considering text wrapping + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + width += w.ws_col; + } + return width; +#endif + } + + void replace_last(char ch) { +#if defined(_WIN32) + pop_cursor(); + put_codepoint(&ch, 1, 1); +#else + fprintf(out, "\b%c", ch); +#endif + } + + void append_utf8(char32_t ch, std::string & out) { + if (ch <= 0x7F) { + out.push_back(static_cast(ch)); + } else if (ch <= 0x7FF) { + 
out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0xFFFF) { + out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0x10FFFF) { + out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); + out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else { + // Invalid Unicode code point + } + } + + // Helper function to remove the last UTF-8 character from a string + void pop_back_utf8_char(std::string & line) { + if (line.empty()) { + return; + } + + size_t pos = line.length() - 1; + + // Find the start of the last UTF-8 character (checking up to 4 bytes back) + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { + if ((line[pos] & 0xC0) != 0x80) { + break; // Found the start of the character + } + } + line.erase(pos); + } + + bool readline_advanced(std::string & line, bool multiline_input) { + if (out != stdout) { + fflush(stdout); + } + + line.clear(); + std::vector widths; + bool is_special_char = false; + bool end_of_stream = false; + + char32_t input_char; + while (true) { + fflush(out); // Ensure all output is displayed before waiting for input + input_char = getchar32(); + + if (input_char == '\r' || input_char == '\n') { + break; + } + + if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) { + end_of_stream = true; + break; + } + + if (is_special_char) { + set_display(user_input); + replace_last(line.back()); + is_special_char = false; + } + + if (input_char == '\033') { // Escape sequence + char32_t code = getchar32(); + if (code == '[' || code == 0x1B) { + // Discard the rest of the escape sequence + while ((code = getchar32()) != (char32_t) WEOF) { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + 
break; + } + } + } + } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace + if (!widths.empty()) { + int count; + do { + count = widths.back(); + widths.pop_back(); + // Move cursor back, print space, and move cursor back again + for (int i = 0; i < count; i++) { + replace_last(' '); + pop_cursor(); + } + pop_back_utf8_char(line); + } while (count == 0 && !widths.empty()); + } + } else { + int offset = line.length(); + append_utf8(input_char, line); + int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); + if (width < 0) { + width = 0; + } + widths.push_back(width); + } + + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + set_display(prompt); + replace_last(line.back()); + is_special_char = true; + } + } + + bool has_more = multiline_input; + if (is_special_char) { + replace_last(' '); + pop_cursor(); + + char last = line.back(); + line.pop_back(); + if (last == '\\') { + line += '\n'; + fputc('\n', out); + has_more = !has_more; + } else { + // llama will just eat the single space, it won't act as a space + if (line.length() == 1 && line.back() == ' ') { + line.clear(); + pop_cursor(); + } + has_more = false; + } + } else { + if (end_of_stream) { + has_more = false; + } else { + line += '\n'; + fputc('\n', out); + } + } + + fflush(out); + return has_more; + } + + bool readline_simple(std::string & line, bool multiline_input) { +#if defined(_WIN32) + std::wstring wline; + if (!std::getline(std::wcin, wline)) { + // Input stream is bad or EOF received + line.clear(); + GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0); + return false; + } + + int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL); + line.resize(size_needed); + WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL); +#else + if (!std::getline(std::cin, line)) { + // Input stream is bad or EOF received + line.clear(); + return false; + } 
+#endif + if (!line.empty()) { + char last = line.back(); + if (last == '/') { // Always return control on '/' symbol + line.pop_back(); + return false; + } + if (last == '\\') { // '\\' changes the default action + line.pop_back(); + multiline_input = !multiline_input; + } + } + line += '\n'; + + // By default, continue input if multiline_input is set + return multiline_input; + } + + bool readline(std::string & line, bool multiline_input) { + set_display(user_input); + + if (simple_io) { + return readline_simple(line, multiline_input); + } + return readline_advanced(line, multiline_input); + } + +} diff --git a/examples/console.h b/examples/console.h new file mode 100644 index 000000000..ec175269b --- /dev/null +++ b/examples/console.h @@ -0,0 +1,19 @@ +// Console functions + +#pragma once + +#include + +namespace console { + enum display_t { + reset = 0, + prompt, + user_input, + error + }; + + void init(bool use_simple_io, bool use_advanced_display); + void cleanup(); + void set_display(display_t display); + bool readline(std::string & line, bool multiline_input); +} diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3796a9230..56ada7e69 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -4,6 +4,7 @@ #endif #include "common.h" +#include "console.h" #include "llama.h" #include "build-info.h" #include "grammar-parser.h" @@ -35,9 +36,7 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static console_state con_st; static llama_context ** g_ctx; - static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) @@ -46,7 +45,7 @@ void sigint_handler(int signo) { if (!is_interacting) { is_interacting=true; } else { - console_cleanup(con_st); + console::cleanup(); printf("\n"); llama_print_timings(*g_ctx); _exit(130); @@ -64,10 +63,8 @@ int main(int argc, char ** argv) { // save choice to use color for later // (note for later: this is a slightly awkward 
choice) - con_st.use_color = params.use_color; - con_st.multiline_input = params.multiline_input; - console_init(con_st); - atexit([]() { console_cleanup(con_st); }); + console::init(params.simple_io, params.use_color); + atexit([]() { console::cleanup(); }); if (params.perplexity) { printf("\n************\n"); @@ -373,7 +370,7 @@ int main(int argc, char ** argv) { if (params.interactive) { const char *control_message; - if (con_st.multiline_input) { + if (params.multiline_input) { control_message = " - To return control to LLaMa, end your input with '\\'.\n" " - To return control without starting a new line, end your input with '/'.\n"; } else { @@ -401,7 +398,7 @@ int main(int argc, char ** argv) { int n_past_guidance = 0; // the first thing we will do is to output the prompt, so set color accordingly - console_set_color(con_st, CONSOLE_COLOR_PROMPT); + console::set_display(console::prompt); std::vector embd; std::vector embd_guidance; @@ -422,9 +419,9 @@ int main(int argc, char ** argv) { // Ensure the input doesn't exceed the context size by truncating embd if necessary. if ((int)embd.size() > max_embd_size) { auto skipped_tokens = embd.size() - max_embd_size; - console_set_color(con_st, CONSOLE_COLOR_ERROR); + console::set_display(console::error); printf("<>", skipped_tokens, skipped_tokens != 1 ? 
"s" : ""); - console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + console::set_display(console::reset); fflush(stdout); embd.resize(max_embd_size); } @@ -667,7 +664,7 @@ int main(int argc, char ** argv) { } // reset color to default if we there is no pending user input if (input_echo && (int)embd_inp.size() == n_consumed) { - console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + console::set_display(console::reset); } // if not currently processing queued inputs; @@ -693,7 +690,7 @@ int main(int argc, char ** argv) { if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) { if (params.interactive) { is_interacting = true; - console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + console::set_display(console::user_input); } is_antiprompt = true; fflush(stdout); @@ -714,7 +711,7 @@ int main(int argc, char ** argv) { is_interacting = true; printf("\n"); - console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + console::set_display(console::user_input); fflush(stdout); } else if (params.instruct) { is_interacting = true; @@ -739,12 +736,12 @@ int main(int argc, char ** argv) { std::string line; bool another_line = true; do { - another_line = console_readline(con_st, line); + another_line = console::readline(line, params.multiline_input); buffer += line; } while (another_line); // done taking input, reset color - console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + console::set_display(console::reset); // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back