Skip to content

Commit 45b2fe1

Browse files
committed
server: split HTTP into its own interface
1 parent 00c9408 commit 45b2fe1

File tree

3 files changed

+531
-1201
lines changed

3 files changed

+531
-1201
lines changed

tools/server/server-http.h

Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
#pragma once
2+
3+
#include "utils.hpp"
4+
#include "common.h"
5+
6+
#include <functional>
7+
#include <string>
8+
#include <thread>
9+
10+
// auto generated files (see README.md for details)
11+
#include "index.html.gz.hpp"
12+
#include "loading.html.hpp"
13+
14+
// Generator-style description of an HTTP response.
//
// A handler fills in `status`, `content_type` and `data`. For a streamed
// response it additionally installs `next`, which is called repeatedly to
// produce the following chunk.
struct server_http_resgen {
    // value used for the Content-Type of the response
    std::string content_type{"application/json; charset=utf-8"};
    // HTTP status code of the response
    int status{200};
    // full body (non-streaming) or the current chunk (streaming)
    std::string data;

    // When set, the response is a stream: next() returns true until the stream
    // ends, and `data` holds the next chunk of data to send.
    // TODO: move this to a virtual function once we have proper polymorphism support
    std::function<bool()> next;

    // true when a `next` callback has been installed
    bool is_stream() const {
        return static_cast<bool>(next);
    }

    virtual ~server_http_resgen() = default;
};
30+
31+
// unique pointer, used by set_chunked_content_provider
32+
// we need to use unique_ptr because httplib requires the stream provider to be stored on the heap
33+
using server_http_resgen_ptr = std::unique_ptr<server_http_resgen>;
34+
35+
struct server_http_request {
36+
std::unordered_map<std::string, std::string> query_params;
37+
json body;
38+
const std::function<bool()> & should_stop;
39+
};
40+
41+
struct server_http_context {
42+
std::thread thread;
43+
std::unique_ptr<httplib::Server> svr;
44+
std::atomic<bool> is_ready = false;
45+
46+
std::string path_prefix;
47+
std::string hostname;
48+
int port;
49+
50+
bool init(const common_params & params);
51+
bool start();
52+
void stop();
53+
54+
using handler_t = std::function<server_http_resgen_ptr(const server_http_request & req)>;
55+
void get(const std::string &, handler_t);
56+
void post(const std::string &, handler_t);
57+
};
58+
59+
// implementation details
60+
61+
// Logger callback installed on the httplib server: prints one info line per
// request (method, path, remote address, status) plus the request and response
// bodies at debug level.
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
    // GH copilot polls this endpoint when the default port is used; keep it out of the log
    const bool is_health_probe = (req.path == "/v1/health");

    // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
    if (!is_health_probe) {
        SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);

        SRV_DBG("request: %s\n", req.body.c_str());
        SRV_DBG("response: %s\n", res.body.c_str());
    }
}
74+
75+
bool server_http_context::init(const common_params & params) {
76+
path_prefix = params.api_prefix;
77+
port = params.port;
78+
hostname = params.hostname;
79+
80+
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
81+
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
82+
LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
83+
svr.reset(
84+
new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
85+
);
86+
} else {
87+
LOG_INF("Running without SSL\n");
88+
svr.reset(new httplib::Server());
89+
}
90+
#else
91+
if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
92+
LOG_ERR("Server is built without SSL support\n");
93+
return false;
94+
}
95+
svr.reset(new httplib::Server());
96+
#endif
97+
98+
svr->set_default_headers({{"Server", "llama.cpp"}});
99+
svr->set_logger(log_server_request);
100+
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
101+
std::string message;
102+
try {
103+
std::rethrow_exception(ep);
104+
} catch (const std::exception & e) {
105+
message = e.what();
106+
} catch (...) {
107+
message = "Unknown Exception";
108+
}
109+
110+
// FIXME
111+
GGML_UNUSED(res);
112+
GGML_UNUSED(message);
113+
// try {
114+
// json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
115+
// LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
116+
// res_error(res, formatted_error);
117+
// } catch (const std::exception & e) {
118+
// LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str());
119+
// }
120+
});
121+
122+
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
123+
if (res.status == 404) {
124+
// FIXME
125+
//res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
126+
res.set_content("404 Not Found", "text/plain");
127+
}
128+
// for other error codes, we skip processing here because it's already done by res_error()
129+
});
130+
131+
// set timeouts and change hostname and port
132+
svr->set_read_timeout (params.timeout_read);
133+
svr->set_write_timeout(params.timeout_write);
134+
135+
if (params.api_keys.size() == 1) {
136+
auto key = params.api_keys[0];
137+
std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
138+
LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
139+
} else if (params.api_keys.size() > 1) {
140+
LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
141+
}
142+
143+
//
144+
// Middlewares
145+
//
146+
147+
auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
148+
static const std::unordered_set<std::string> public_endpoints = {
149+
"/health",
150+
"/v1/health",
151+
"/models",
152+
"/v1/models",
153+
"/api/tags"
154+
};
155+
156+
// If API key is not set, skip validation
157+
if (api_keys.empty()) {
158+
return true;
159+
}
160+
161+
// If path is public or is static file, skip validation
162+
if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
163+
return true;
164+
}
165+
166+
// Check for API key in the header
167+
auto auth_header = req.get_header_value("Authorization");
168+
169+
std::string prefix = "Bearer ";
170+
if (auth_header.substr(0, prefix.size()) == prefix) {
171+
std::string received_api_key = auth_header.substr(prefix.size());
172+
if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
173+
return true; // API key is valid
174+
}
175+
}
176+
177+
// API key is invalid or not provided
178+
//res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
179+
// FIXME
180+
res.status = 401;
181+
res.set_content("Unauthorized: Invalid API Key", "text/plain");
182+
183+
LOG_WRN("Unauthorized: Invalid API Key\n");
184+
185+
return false;
186+
};
187+
188+
auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
189+
bool ready = is_ready.load();
190+
if (!ready) {
191+
auto tmp = string_split<std::string>(req.path, '.');
192+
if (req.path == "/" || tmp.back() == "html") {
193+
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
194+
res.status = 503;
195+
} else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
196+
// allow the models endpoint to be accessed during loading
197+
return true;
198+
} else {
199+
// FIXME
200+
//res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
201+
res.status = 503;
202+
res.set_content("503 Service Unavailable: Loading model", "text/plain");
203+
}
204+
return false;
205+
}
206+
return true;
207+
};
208+
209+
// register server middlewares
210+
svr->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
211+
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
212+
// If this is OPTIONS request, skip validation because browsers don't include Authorization header
213+
if (req.method == "OPTIONS") {
214+
res.set_header("Access-Control-Allow-Credentials", "true");
215+
res.set_header("Access-Control-Allow-Methods", "GET, POST");
216+
res.set_header("Access-Control-Allow-Headers", "*");
217+
res.set_content("", "text/html"); // blank response, no data
218+
return httplib::Server::HandlerResponse::Handled; // skip further processing
219+
}
220+
if (!middleware_server_state(req, res)) {
221+
return httplib::Server::HandlerResponse::Handled;
222+
}
223+
if (!middleware_validate_api_key(req, res)) {
224+
return httplib::Server::HandlerResponse::Handled;
225+
}
226+
return httplib::Server::HandlerResponse::Unhandled;
227+
});
228+
229+
int n_threads_http = params.n_threads_http;
230+
if (n_threads_http < 1) {
231+
// +2 threads for monitoring endpoints
232+
n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
233+
}
234+
LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
235+
svr->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };
236+
237+
//
238+
// Web UI setup
239+
//
240+
241+
if (!params.webui) {
242+
LOG_INF("Web UI is disabled\n");
243+
} else {
244+
// register static assets routes
245+
if (!params.public_path.empty()) {
246+
// Set the base directory for serving static files
247+
bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path);
248+
if (!is_found) {
249+
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
250+
return 1;
251+
}
252+
} else {
253+
// using embedded static index.html
254+
svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
255+
if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
256+
res.set_content("Error: gzip is not supported by this browser", "text/plain");
257+
} else {
258+
res.set_header("Content-Encoding", "gzip");
259+
// COEP and COOP headers, required by pyodide (python interpreter)
260+
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
261+
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
262+
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
263+
}
264+
return false;
265+
});
266+
}
267+
}
268+
return true;
269+
}
270+
271+
bool server_http_context::start() {
272+
// Bind and listen
273+
274+
bool was_bound = false;
275+
bool is_sock = false;
276+
if (string_ends_with(std::string(hostname), ".sock")) {
277+
is_sock = true;
278+
LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
279+
svr->set_address_family(AF_UNIX);
280+
// bind_to_port requires a second arg, any value other than 0 should
281+
// simply get ignored
282+
was_bound = svr->bind_to_port(hostname, 8080);
283+
} else {
284+
LOG_INF("%s: binding port with default address family\n", __func__);
285+
// bind HTTP listen port
286+
if (port == 0) {
287+
int bound_port = svr->bind_to_any_port(hostname);
288+
if ((was_bound = (bound_port >= 0))) {
289+
port = bound_port;
290+
}
291+
} else {
292+
was_bound = svr->bind_to_port(hostname, port);
293+
}
294+
}
295+
296+
if (!was_bound) {
297+
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
298+
return false;
299+
}
300+
301+
// run the HTTP server in a thread
302+
thread = std::thread([this]() { svr->listen_after_bind(); });
303+
svr->wait_until_ready();
304+
305+
LOG_INF("%s: server is listening on %s\n", __func__,
306+
is_sock ? string_format("unix://%s", hostname.c_str()).c_str() :
307+
string_format("http://%s:%d", hostname.c_str(), port).c_str());
308+
return true;
309+
}
310+
311+
void server_http_context::stop() {
312+
if (svr) {
313+
svr->stop();
314+
}
315+
}
316+
317+
// Register `handler` for GET requests on `path_prefix + path`.
//
// The handler receives httplib's path params, an empty JSON body and the
// connection-closed callback. Streaming responses are not supported for GET
// (asserted).
void server_http_context::get(const std::string & path, server_http_context::handler_t handler) {
    const auto route = [handler](const httplib::Request & req, httplib::Response & res) {
        server_http_resgen_ptr result = handler(server_http_request{
            req.path_params,
            json{},
            req.is_connection_closed
        });
        GGML_ASSERT(!result->is_stream() && "not supported for GET method");

        // copy the generated response over to httplib
        res.status = result->status;
        res.set_content(result->data, result->content_type);
    };

    svr->Get(path_prefix + path, route);
}
329+
330+
// Register `handler` for POST requests on `path_prefix + path`.
//
// The request body is parsed as JSON (an empty body is treated as "{}").
// A non-streaming result is sent as-is. A streaming result is sent through
// httplib's chunked content provider: each call writes the current `data`
// chunk and then calls next() to prepare the following one.
void server_http_context::post(const std::string & path, server_http_context::handler_t handler) {
    svr->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
        server_http_resgen_ptr result = handler(server_http_request{
            req.path_params,
            json::parse(req.body.empty() ? "{}" : req.body),
            req.is_connection_closed
        });

        // the status is set the same way for both kinds of response
        res.status = result->status;

        if (!result->is_stream()) {
            res.set_content(result->data, result->content_type);
            return;
        }

        std::string content_type = result->content_type;

        // convert to shared_ptr as both the content provider and the completion callback need to use it
        std::shared_ptr<server_http_resgen> stream = std::move(result);

        const auto provider = [stream](size_t, httplib::DataSink & sink) -> bool {
            // TODO: maybe handle sink.write unsuccessful case? for now, we rely on is_connection_closed()
            sink.write(stream->data.data(), stream->data.size());
            SRV_DBG("http: streamed chunk: %s\n", stream->data.c_str());
            if (stream->next()) {
                return true; // more chunks to come
            }
            SRV_DBG("%s", "http: stream ended\n");
            sink.done();
            return false; // end of stream
        };

        const auto on_complete = [stream](bool) mutable {
            stream.reset(); // trigger the destruction of the response object
        };

        res.set_chunked_content_provider(content_type, provider, on_complete);
    });
}

0 commit comments

Comments
 (0)