|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "utils.hpp" |
| 4 | +#include "common.h" |
| 5 | + |
| 6 | +#include <functional> |
| 7 | +#include <string> |
| 8 | +#include <thread> |
| 9 | + |
| 10 | +// auto generated files (see README.md for details) |
| 11 | +#include "index.html.gz.hpp" |
| 12 | +#include "loading.html.hpp" |
| 13 | + |
// Generator-style description of an HTTP response.
//
// A handler fills in `status`, `content_type` and `data`. Installing `next`
// turns the response into a stream (see is_stream()).
struct server_http_resgen {
    std::string content_type = "application/json; charset=utf-8";
    int status = 200;
    std::string data;

    // Streaming hook (empty by default, meaning "not a stream").
    // When installed, next() returns true until the stream ends, and `data`
    // holds the next chunk of data to send.
    // TODO: move this to a virtual function once we have proper polymorphism support
    std::function<bool()> next;

    // True when a `next` callback has been installed.
    bool is_stream() const {
        return static_cast<bool>(next);
    }

    virtual ~server_http_resgen() = default;
};

// Owning pointer to a response generator; this is what handlers return.
// It is heap-allocated (unique_ptr) because httplib requires the stream provider
// passed to set_chunked_content_provider to be stored in heap.
using server_http_resgen_ptr = std::unique_ptr<server_http_resgen>;
| 34 | + |
| 35 | +struct server_http_request { |
| 36 | + std::unordered_map<std::string, std::string> query_params; |
| 37 | + json body; |
| 38 | + const std::function<bool()> & should_stop; |
| 39 | +}; |
| 40 | + |
| 41 | +struct server_http_context { |
| 42 | + std::thread thread; |
| 43 | + std::unique_ptr<httplib::Server> svr; |
| 44 | + std::atomic<bool> is_ready = false; |
| 45 | + |
| 46 | + std::string path_prefix; |
| 47 | + std::string hostname; |
| 48 | + int port; |
| 49 | + |
| 50 | + bool init(const common_params & params); |
| 51 | + bool start(); |
| 52 | + void stop(); |
| 53 | + |
| 54 | + using handler_t = std::function<server_http_resgen_ptr(const server_http_request & req)>; |
| 55 | + void get(const std::string &, handler_t); |
| 56 | + void post(const std::string &, handler_t); |
| 57 | +}; |
| 58 | + |
| 59 | +// implementation details |
| 60 | + |
| 61 | +static void log_server_request(const httplib::Request & req, const httplib::Response & res) { |
| 62 | + // skip GH copilot requests when using default port |
| 63 | + if (req.path == "/v1/health") { |
| 64 | + return; |
| 65 | + } |
| 66 | + |
| 67 | + // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch |
| 68 | + |
| 69 | + SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); |
| 70 | + |
| 71 | + SRV_DBG("request: %s\n", req.body.c_str()); |
| 72 | + SRV_DBG("response: %s\n", res.body.c_str()); |
| 73 | +} |
| 74 | + |
| 75 | +bool server_http_context::init(const common_params & params) { |
| 76 | + path_prefix = params.api_prefix; |
| 77 | + port = params.port; |
| 78 | + hostname = params.hostname; |
| 79 | + |
| 80 | +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT |
| 81 | + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { |
| 82 | + LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str()); |
| 83 | + svr.reset( |
| 84 | + new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) |
| 85 | + ); |
| 86 | + } else { |
| 87 | + LOG_INF("Running without SSL\n"); |
| 88 | + svr.reset(new httplib::Server()); |
| 89 | + } |
| 90 | +#else |
| 91 | + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { |
| 92 | + LOG_ERR("Server is built without SSL support\n"); |
| 93 | + return false; |
| 94 | + } |
| 95 | + svr.reset(new httplib::Server()); |
| 96 | +#endif |
| 97 | + |
| 98 | + svr->set_default_headers({{"Server", "llama.cpp"}}); |
| 99 | + svr->set_logger(log_server_request); |
| 100 | + svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { |
| 101 | + std::string message; |
| 102 | + try { |
| 103 | + std::rethrow_exception(ep); |
| 104 | + } catch (const std::exception & e) { |
| 105 | + message = e.what(); |
| 106 | + } catch (...) { |
| 107 | + message = "Unknown Exception"; |
| 108 | + } |
| 109 | + |
| 110 | + // FIXME |
| 111 | + GGML_UNUSED(res); |
| 112 | + GGML_UNUSED(message); |
| 113 | + // try { |
| 114 | + // json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); |
| 115 | + // LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); |
| 116 | + // res_error(res, formatted_error); |
| 117 | + // } catch (const std::exception & e) { |
| 118 | + // LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); |
| 119 | + // } |
| 120 | + }); |
| 121 | + |
| 122 | + svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { |
| 123 | + if (res.status == 404) { |
| 124 | + // FIXME |
| 125 | + //res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); |
| 126 | + res.set_content("404 Not Found", "text/plain"); |
| 127 | + } |
| 128 | + // for other error codes, we skip processing here because it's already done by res_error() |
| 129 | + }); |
| 130 | + |
| 131 | + // set timeouts and change hostname and port |
| 132 | + svr->set_read_timeout (params.timeout_read); |
| 133 | + svr->set_write_timeout(params.timeout_write); |
| 134 | + |
| 135 | + if (params.api_keys.size() == 1) { |
| 136 | + auto key = params.api_keys[0]; |
| 137 | + std::string substr = key.substr(std::max((int)(key.length() - 4), 0)); |
| 138 | + LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str()); |
| 139 | + } else if (params.api_keys.size() > 1) { |
| 140 | + LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size()); |
| 141 | + } |
| 142 | + |
| 143 | + // |
| 144 | + // Middlewares |
| 145 | + // |
| 146 | + |
| 147 | + auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) { |
| 148 | + static const std::unordered_set<std::string> public_endpoints = { |
| 149 | + "/health", |
| 150 | + "/v1/health", |
| 151 | + "/models", |
| 152 | + "/v1/models", |
| 153 | + "/api/tags" |
| 154 | + }; |
| 155 | + |
| 156 | + // If API key is not set, skip validation |
| 157 | + if (api_keys.empty()) { |
| 158 | + return true; |
| 159 | + } |
| 160 | + |
| 161 | + // If path is public or is static file, skip validation |
| 162 | + if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { |
| 163 | + return true; |
| 164 | + } |
| 165 | + |
| 166 | + // Check for API key in the header |
| 167 | + auto auth_header = req.get_header_value("Authorization"); |
| 168 | + |
| 169 | + std::string prefix = "Bearer "; |
| 170 | + if (auth_header.substr(0, prefix.size()) == prefix) { |
| 171 | + std::string received_api_key = auth_header.substr(prefix.size()); |
| 172 | + if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) { |
| 173 | + return true; // API key is valid |
| 174 | + } |
| 175 | + } |
| 176 | + |
| 177 | + // API key is invalid or not provided |
| 178 | + //res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); |
| 179 | + // FIXME |
| 180 | + res.status = 401; |
| 181 | + res.set_content("Unauthorized: Invalid API Key", "text/plain"); |
| 182 | + |
| 183 | + LOG_WRN("Unauthorized: Invalid API Key\n"); |
| 184 | + |
| 185 | + return false; |
| 186 | + }; |
| 187 | + |
| 188 | + auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) { |
| 189 | + bool ready = is_ready.load(); |
| 190 | + if (!ready) { |
| 191 | + auto tmp = string_split<std::string>(req.path, '.'); |
| 192 | + if (req.path == "/" || tmp.back() == "html") { |
| 193 | + res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8"); |
| 194 | + res.status = 503; |
| 195 | + } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { |
| 196 | + // allow the models endpoint to be accessed during loading |
| 197 | + return true; |
| 198 | + } else { |
| 199 | + // FIXME |
| 200 | + //res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); |
| 201 | + res.status = 503; |
| 202 | + res.set_content("503 Service Unavailable: Loading model", "text/plain"); |
| 203 | + } |
| 204 | + return false; |
| 205 | + } |
| 206 | + return true; |
| 207 | + }; |
| 208 | + |
| 209 | + // register server middlewares |
| 210 | + svr->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) { |
| 211 | + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); |
| 212 | + // If this is OPTIONS request, skip validation because browsers don't include Authorization header |
| 213 | + if (req.method == "OPTIONS") { |
| 214 | + res.set_header("Access-Control-Allow-Credentials", "true"); |
| 215 | + res.set_header("Access-Control-Allow-Methods", "GET, POST"); |
| 216 | + res.set_header("Access-Control-Allow-Headers", "*"); |
| 217 | + res.set_content("", "text/html"); // blank response, no data |
| 218 | + return httplib::Server::HandlerResponse::Handled; // skip further processing |
| 219 | + } |
| 220 | + if (!middleware_server_state(req, res)) { |
| 221 | + return httplib::Server::HandlerResponse::Handled; |
| 222 | + } |
| 223 | + if (!middleware_validate_api_key(req, res)) { |
| 224 | + return httplib::Server::HandlerResponse::Handled; |
| 225 | + } |
| 226 | + return httplib::Server::HandlerResponse::Unhandled; |
| 227 | + }); |
| 228 | + |
| 229 | + int n_threads_http = params.n_threads_http; |
| 230 | + if (n_threads_http < 1) { |
| 231 | + // +2 threads for monitoring endpoints |
| 232 | + n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); |
| 233 | + } |
| 234 | + LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http); |
| 235 | + svr->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); }; |
| 236 | + |
| 237 | + // |
| 238 | + // Web UI setup |
| 239 | + // |
| 240 | + |
| 241 | + if (!params.webui) { |
| 242 | + LOG_INF("Web UI is disabled\n"); |
| 243 | + } else { |
| 244 | + // register static assets routes |
| 245 | + if (!params.public_path.empty()) { |
| 246 | + // Set the base directory for serving static files |
| 247 | + bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path); |
| 248 | + if (!is_found) { |
| 249 | + LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); |
| 250 | + return 1; |
| 251 | + } |
| 252 | + } else { |
| 253 | + // using embedded static index.html |
| 254 | + svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) { |
| 255 | + if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { |
| 256 | + res.set_content("Error: gzip is not supported by this browser", "text/plain"); |
| 257 | + } else { |
| 258 | + res.set_header("Content-Encoding", "gzip"); |
| 259 | + // COEP and COOP headers, required by pyodide (python interpreter) |
| 260 | + res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); |
| 261 | + res.set_header("Cross-Origin-Opener-Policy", "same-origin"); |
| 262 | + res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); |
| 263 | + } |
| 264 | + return false; |
| 265 | + }); |
| 266 | + } |
| 267 | + } |
| 268 | + return true; |
| 269 | +} |
| 270 | + |
| 271 | +bool server_http_context::start() { |
| 272 | + // Bind and listen |
| 273 | + |
| 274 | + bool was_bound = false; |
| 275 | + bool is_sock = false; |
| 276 | + if (string_ends_with(std::string(hostname), ".sock")) { |
| 277 | + is_sock = true; |
| 278 | + LOG_INF("%s: setting address family to AF_UNIX\n", __func__); |
| 279 | + svr->set_address_family(AF_UNIX); |
| 280 | + // bind_to_port requires a second arg, any value other than 0 should |
| 281 | + // simply get ignored |
| 282 | + was_bound = svr->bind_to_port(hostname, 8080); |
| 283 | + } else { |
| 284 | + LOG_INF("%s: binding port with default address family\n", __func__); |
| 285 | + // bind HTTP listen port |
| 286 | + if (port == 0) { |
| 287 | + int bound_port = svr->bind_to_any_port(hostname); |
| 288 | + if ((was_bound = (bound_port >= 0))) { |
| 289 | + port = bound_port; |
| 290 | + } |
| 291 | + } else { |
| 292 | + was_bound = svr->bind_to_port(hostname, port); |
| 293 | + } |
| 294 | + } |
| 295 | + |
| 296 | + if (!was_bound) { |
| 297 | + LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port); |
| 298 | + return false; |
| 299 | + } |
| 300 | + |
| 301 | + // run the HTTP server in a thread |
| 302 | + thread = std::thread([this]() { svr->listen_after_bind(); }); |
| 303 | + svr->wait_until_ready(); |
| 304 | + |
| 305 | + LOG_INF("%s: server is listening on %s\n", __func__, |
| 306 | + is_sock ? string_format("unix://%s", hostname.c_str()).c_str() : |
| 307 | + string_format("http://%s:%d", hostname.c_str(), port).c_str()); |
| 308 | + return true; |
| 309 | +} |
| 310 | + |
| 311 | +void server_http_context::stop() { |
| 312 | + if (svr) { |
| 313 | + svr->stop(); |
| 314 | + } |
| 315 | +} |
| 316 | + |
| 317 | +void server_http_context::get(const std::string & path, server_http_context::handler_t handler) { |
| 318 | + svr->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { |
| 319 | + server_http_resgen_ptr response = handler(server_http_request{ |
| 320 | + req.path_params, |
| 321 | + json{}, |
| 322 | + req.is_connection_closed |
| 323 | + }); |
| 324 | + GGML_ASSERT(!response->is_stream() && "not supported for GET method"); |
| 325 | + res.status = response->status; |
| 326 | + res.set_content(response->data, response->content_type); |
| 327 | + }); |
| 328 | +} |
| 329 | + |
| 330 | +void server_http_context::post(const std::string & path, server_http_context::handler_t handler) { |
| 331 | + svr->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { |
| 332 | + server_http_resgen_ptr response = handler(server_http_request{ |
| 333 | + req.path_params, |
| 334 | + json::parse(req.body.empty() ? "{}" : req.body), |
| 335 | + req.is_connection_closed |
| 336 | + }); |
| 337 | + if (response->is_stream()) { |
| 338 | + res.status = response->status; |
| 339 | + std::string content_type = response->content_type; |
| 340 | + // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it |
| 341 | + std::shared_ptr<server_http_resgen> r_ptr = std::move(response); |
| 342 | + const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool { |
| 343 | + // TODO: maybe handle sink.write unsuccessful case? for now, we rely on is_connection_closed() |
| 344 | + sink.write(response->data.data(), response->data.size()); |
| 345 | + SRV_DBG("http: streamed chunk: %s\n", response->data.c_str()); |
| 346 | + if (!response->next()) { |
| 347 | + SRV_DBG("%s", "http: stream ended\n"); |
| 348 | + sink.done(); |
| 349 | + return false; // end of stream |
| 350 | + } |
| 351 | + return true; |
| 352 | + }; |
| 353 | + const auto on_complete = [response = r_ptr](bool) mutable { |
| 354 | + response.reset(); // trigger the destruction of the response object |
| 355 | + }; |
| 356 | + res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete); |
| 357 | + } else { |
| 358 | + res.status = response->status; |
| 359 | + res.set_content(response->data, response->content_type); |
| 360 | + } |
| 361 | + }); |
| 362 | +} |
0 commit comments