From e1fa9569ba8ce276bc7801a3cebdcf8b1aa116ea Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Sat, 9 Mar 2024 02:57:09 -0700 Subject: [PATCH] server : add SSL support (#5926) * add cmake build toggle to enable ssl support in server Signed-off-by: Gabe Goodhart * add flags for ssl key/cert files and use SSLServer if set All SSL setup is hidden behind CPPHTTPLIB_OPENSSL_SUPPORT in the same way that the base httlib hides the SSL support Signed-off-by: Gabe Goodhart * Update readme for SSL support in server Signed-off-by: Gabe Goodhart * Add LLAMA_SERVER_SSL variable setup to top-level Makefile Signed-off-by: Gabe Goodhart --------- Signed-off-by: Gabe Goodhart --- Makefile | 4 ++ examples/server/CMakeLists.txt | 6 ++ examples/server/README.md | 26 ++++++++ examples/server/server.cpp | 108 ++++++++++++++++++++++----------- 4 files changed, 110 insertions(+), 34 deletions(-) diff --git a/Makefile b/Makefile index efce10bb8..aea969222 100644 --- a/Makefile +++ b/Makefile @@ -201,6 +201,10 @@ ifdef LLAMA_SERVER_VERBOSE MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE) endif +ifdef LLAMA_SERVER_SSL + MK_CPPFLAGS += -DCPPHTTPLIB_OPENSSL_SUPPORT + MK_LDFLAGS += -lssl -lcrypto +endif ifdef LLAMA_CODE_COVERAGE MK_CXXFLAGS += -fprofile-arcs -ftest-coverage -dumpbase '' diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index c21eba634..f94de1e99 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET server) option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON) +option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h) install(TARGETS ${TARGET} RUNTIME) @@ -7,6 +8,11 @@ target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$ ) target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) +if (LLAMA_SERVER_SSL) + find_package(OpenSSL REQUIRED) + target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto) + target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT) +endif() if (WIN32) TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32) endif() diff --git a/examples/server/README.md b/examples/server/README.md index 591f748f8..bf8c450b6 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -59,6 +59,10 @@ see https://github.com/ggerganov/llama.cpp/issues/1437 - `--log-disable`: Output logs to stdout only, default: enabled. - `--log-format FORMAT`: Define the log output to FORMAT: json or text (default: json) +**If compiled with `LLAMA_SERVER_SSL=ON`** +- `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key +- `--ssl-cert-file FNAME`: path to file a PEM-encoded SSL certificate + ## Build server is build alongside everything else from the root of the project @@ -75,6 +79,28 @@ server is build alongside everything else from the root of the project cmake --build . --config Release ``` +## Build with SSL + +server can also be built with SSL support using OpenSSL 3 + +- Using `make`: + + ```bash + # NOTE: For non-system openssl, use the following: + # CXXFLAGS="-I /path/to/openssl/include" + # LDFLAGS="-L /path/to/openssl/lib" + make LLAMA_SERVER_SSL=true server + ``` + +- Using `CMake`: + + ```bash + mkdir build + cd build + cmake .. -DLLAMA_SERVER_SSL=ON + make server + ``` + ## Quick Start To get started right away, run the following command, making sure to use the correct path for the model you have: diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6f4449984..c3b87c846 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -27,6 +27,7 @@ #include #include #include +#include using json = nlohmann::json; @@ -118,6 +119,11 @@ struct server_params { std::vector api_keys; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string ssl_key_file = ""; + std::string ssl_cert_file = ""; +#endif + bool slots_endpoint = true; bool metrics_endpoint = false; }; @@ -2142,6 +2148,10 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + printf(" --ssl-key-file FNAME path to file a PEM-encoded SSL private key\n"); + printf(" --ssl-cert-file FNAME path to file a PEM-encoded SSL certificate\n"); +#endif printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); @@ -2220,7 +2230,24 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, } } key_file.close(); - } else if (arg == "--timeout" || arg == "-to") { + + } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + else if (arg == "--ssl-key-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.ssl_key_file = argv[i]; + } else if (arg == "--ssl-cert-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.ssl_cert_file = argv[i]; + } +#endif + else if (arg == "--timeout" || arg == "-to") { if (++i >= argc) { invalid_param = true; break; @@ -2658,21 +2685,34 @@ int main(int argc, char ** argv) { {"system_info", llama_print_system_info()}, }); - httplib::Server svr; + std::unique_ptr svr; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (sparams.ssl_key_file != "" && sparams.ssl_cert_file != "") { + LOG_INFO("Running with SSL", {{"key", sparams.ssl_key_file}, {"cert", sparams.ssl_cert_file}}); + svr.reset( + new httplib::SSLServer(sparams.ssl_cert_file.c_str(), sparams.ssl_key_file.c_str()) + ); + } else { + LOG_INFO("Running without SSL", {}); + svr.reset(new httplib::Server()); + } +#else + svr.reset(new httplib::Server()); +#endif std::atomic state{SERVER_STATE_LOADING_MODEL}; - svr.set_default_headers({{"Server", "llama.cpp"}}); + svr->set_default_headers({{"Server", "llama.cpp"}}); // CORS preflight - svr.Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) { + svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "POST"); res.set_header("Access-Control-Allow-Headers", "*"); }); - svr.Get("/health", [&](const httplib::Request & req, httplib::Response & res) { + svr->Get("/health", [&](const httplib::Request & req, httplib::Response & res) { server_state current_state = state.load(); switch (current_state) { case SERVER_STATE_READY: @@ -2728,7 +2768,7 @@ int main(int argc, char ** argv) { }); if (sparams.slots_endpoint) { - svr.Get("/slots", [&](const httplib::Request &, httplib::Response & res) { + svr->Get("/slots", [&](const httplib::Request &, httplib::Response & res) { // request slots data using task queue server_task task; task.id = ctx_server.queue_tasks.get_new_id(); @@ -2749,7 +2789,7 @@ int main(int argc, char ** argv) { } if (sparams.metrics_endpoint) { - svr.Get("/metrics", [&](const httplib::Request &, httplib::Response & res) { + svr->Get("/metrics", [&](const httplib::Request &, httplib::Response & res) { // request slots data using task queue server_task task; task.id = ctx_server.queue_tasks.get_new_id(); @@ -2846,9 +2886,9 @@ int main(int argc, char ** argv) { }); } - svr.set_logger(log_server_request); + svr->set_logger(log_server_request); - svr.set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { const char fmt[] = "500 Internal Server Error\n%s"; char buf[BUFSIZ]; @@ -2864,7 +2904,7 @@ int main(int argc, char ** argv) { res.status = 500; }); - svr.set_error_handler([](const httplib::Request &, httplib::Response & res) { + svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { if (res.status == 401) { res.set_content("Unauthorized", "text/plain; charset=utf-8"); } @@ -2877,16 +2917,16 @@ int main(int argc, char ** argv) { }); // set timeouts and change hostname and port - svr.set_read_timeout (sparams.read_timeout); - svr.set_write_timeout(sparams.write_timeout); + svr->set_read_timeout (sparams.read_timeout); + svr->set_write_timeout(sparams.write_timeout); - if (!svr.bind_to_port(sparams.hostname, sparams.port)) { + if (!svr->bind_to_port(sparams.hostname, sparams.port)) { fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); return 1; } // Set the base directory for serving static files - svr.set_base_dir(sparams.public_path); + svr->set_base_dir(sparams.public_path); std::unordered_map log_data; @@ -2947,30 +2987,30 @@ int main(int argc, char ** argv) { }; // this is only called if no index.html is found in the public --path - svr.Get("/", [](const httplib::Request &, httplib::Response & res) { + svr->Get("/", [](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(&index_html), index_html_len, "text/html; charset=utf-8"); return false; }); // this is only called if no index.js is found in the public --path - svr.Get("/index.js", [](const httplib::Request &, httplib::Response & res) { + svr->Get("/index.js", [](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(&index_js), index_js_len, "text/javascript; charset=utf-8"); return false; }); // this is only called if no index.html is found in the public --path - svr.Get("/completion.js", [](const httplib::Request &, httplib::Response & res) { + svr->Get("/completion.js", [](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(&completion_js), completion_js_len, "application/javascript; charset=utf-8"); return false; }); // this is only called if no index.html is found in the public --path - svr.Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) { + svr->Get("/json-schema-to-grammar.mjs", [](const httplib::Request &, httplib::Response & res) { res.set_content(reinterpret_cast(&json_schema_to_grammar_mjs), json_schema_to_grammar_mjs_len, "application/javascript; charset=utf-8"); return false; }); - svr.Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) { + svr->Get("/props", [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "user_name", ctx_server.name_user.c_str() }, @@ -3062,11 +3102,11 @@ int main(int argc, char ** argv) { } }; - svr.Post("/completion", completions); // legacy - svr.Post("/completions", completions); - svr.Post("/v1/completions", completions); + svr->Post("/completion", completions); // legacy + svr->Post("/completions", completions); + svr->Post("/v1/completions", completions); - svr.Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { + svr->Get("/v1/models", [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json models = { @@ -3161,10 +3201,10 @@ int main(int argc, char ** argv) { } }; - svr.Post("/chat/completions", chat_completions); - svr.Post("/v1/chat/completions", chat_completions); + svr->Post("/chat/completions", chat_completions); + svr->Post("/v1/chat/completions", chat_completions); - svr.Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { + svr->Post("/infill", [&ctx_server, &validate_api_key](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; @@ -3228,11 +3268,11 @@ int main(int argc, char ** argv) { } }); - svr.Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) { + svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) { return res.set_content("", "application/json; charset=utf-8"); }); - svr.Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { + svr->Post("/tokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); @@ -3244,7 +3284,7 @@ int main(int argc, char ** argv) { return res.set_content(data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { + svr->Post("/detokenize", [&ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); @@ -3258,7 +3298,7 @@ int main(int argc, char ** argv) { return res.set_content(data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/embedding", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { + svr->Post("/embedding", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!params.embedding) { res.status = 501; @@ -3289,7 +3329,7 @@ int main(int argc, char ** argv) { return res.set_content(result.data.dump(), "application/json; charset=utf-8"); }); - svr.Post("/v1/embeddings", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { + svr->Post("/v1/embeddings", [¶ms, &ctx_server](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!params.embedding) { res.status = 501; @@ -3360,13 +3400,13 @@ int main(int argc, char ** argv) { sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); } log_data["n_threads_http"] = std::to_string(sparams.n_threads_http); - svr.new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); }; + svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); }; LOG_INFO("HTTP server listening", log_data); // run the HTTP server in a thread - see comment below std::thread t([&]() { - if (!svr.listen_after_bind()) { + if (!svr->listen_after_bind()) { state.store(SERVER_STATE_ERROR); return 1; } @@ -3407,7 +3447,7 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.start_loop(); - svr.stop(); + svr->stop(); t.join(); llama_backend_free();