From 88ae8952b65cbf32eb1f5703681ea592e510e570 Mon Sep 17 00:00:00 2001 From: ShadovvBeast Date: Fri, 15 Dec 2023 13:49:01 +0200 Subject: [PATCH] server : add optional API Key Authentication example (#4441) * Add API key authentication for enhanced server-client security * server : to snake_case --------- Co-authored-by: Georgi Gerganov --- examples/server/public/completion.js | 3 +- examples/server/public/index.html | 7 ++- examples/server/server.cpp | 70 ++++++++++++++++++++++++---- 3 files changed, 70 insertions(+), 10 deletions(-) diff --git a/examples/server/public/completion.js b/examples/server/public/completion.js index c281f0fbd..6e2b99565 100644 --- a/examples/server/public/completion.js +++ b/examples/server/public/completion.js @@ -34,7 +34,8 @@ export async function* llama(prompt, params = {}, config = {}) { headers: { 'Connection': 'keep-alive', 'Content-Type': 'application/json', - 'Accept': 'text/event-stream' + 'Accept': 'text/event-stream', + ...(params.api_key ? {'Authorization': `Bearer ${params.api_key}`} : {}) }, signal: controller.signal, }); diff --git a/examples/server/public/index.html b/examples/server/public/index.html index 451fd4a3b..07d779d20 100644 --- a/examples/server/public/index.html +++ b/examples/server/public/index.html @@ -235,7 +235,8 @@ grammar: '', n_probs: 0, // no completion_probabilities, image_data: [], - cache_prompt: true + cache_prompt: true, + api_key: '' }) /* START: Support for storing prompt templates and parameters in browsers LocalStorage */ @@ -790,6 +791,10 @@
${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
+
+ + +
` diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 39d1e83d1..5f93dcb66 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -36,6 +36,7 @@ using json = nlohmann::json; struct server_params { std::string hostname = "127.0.0.1"; + std::string api_key; std::string public_path = "examples/server/public"; int32_t port = 8080; int32_t read_timeout = 600; @@ -1953,6 +1954,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str()); printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); + printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); @@ -2002,6 +2004,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } sparams.public_path = argv[i]; } + else if (arg == "--api-key") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + sparams.api_key = argv[i]; + } else if (arg == "--timeout" || arg == "-to") { if (++i >= argc) @@ -2669,6 +2680,32 @@ int main(int argc, char **argv) httplib::Server svr; + // Middleware for API key validation + auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { + // If API key is not set, skip validation + if (sparams.api_key.empty()) { + return true; + } + + // Check for API key in the header + auto auth_header = req.get_header_value("Authorization"); + std::string prefix = "Bearer "; + if (auth_header.substr(0, prefix.size()) == prefix) { + std::string received_api_key = auth_header.substr(prefix.size()); + if (received_api_key == sparams.api_key) { + return true; // API key is valid + } + } + + // API key is invalid or not provided + res.set_content("Unauthorized: Invalid API Key", "text/plain"); + res.status = 401; // Unauthorized + + LOG_WARNING("Unauthorized: Invalid API Key", {}); + + return false; + }; + svr.set_default_headers({{"Server", "llama.cpp"}, {"Access-Control-Allow-Origin", "*"}, {"Access-Control-Allow-Headers", "content-type"}}); @@ -2711,8 +2748,11 @@ int main(int argc, char **argv) res.set_content(data.dump(), "application/json"); }); - svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res) + svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + if (!validate_api_key(req, res)) { + return; + } json data = json::parse(req.body); const int task_id = llama.request_completion(data, false, false, -1); if (!json_value(data, "stream", false)) { @@ -2799,8 +2839,11 @@ int main(int argc, char **argv) }); // TODO: add mount point without "/v1" prefix -- how? - svr.Post("/v1/chat/completions", [&llama](const httplib::Request &req, httplib::Response &res) + svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + if (!validate_api_key(req, res)) { + return; + } json data = oaicompat_completion_params_parse(json::parse(req.body)); const int task_id = llama.request_completion(data, false, false, -1); @@ -2869,8 +2912,11 @@ int main(int argc, char **argv) } }); - svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res) + svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + if (!validate_api_key(req, res)) { + return; + } json data = json::parse(req.body); const int task_id = llama.request_completion(data, true, false, -1); if (!json_value(data, "stream", false)) { @@ -3005,11 +3051,15 @@ int main(int argc, char **argv) svr.set_error_handler([](const httplib::Request &, httplib::Response &res) { + if (res.status == 401) + { + res.set_content("Unauthorized", "text/plain"); + } if (res.status == 400) { res.set_content("Invalid request", "text/plain"); } - else if (res.status != 500) + else if (res.status == 404) { res.set_content("File Not Found", "text/plain"); res.status = 404; @@ -3032,11 +3082,15 @@ int main(int argc, char **argv) // to make it ctrl+clickable: LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); - LOG_INFO("HTTP server listening", { - {"hostname", sparams.hostname}, - {"port", sparams.port}, - }); + std::unordered_map log_data; + log_data["hostname"] = sparams.hostname; + log_data["port"] = std::to_string(sparams.port); + if (!sparams.api_key.empty()) { + log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4); + } + + LOG_INFO("HTTP server listening", log_data); // run the HTTP server in a thread - see comment below std::thread t([&]() {