@@ -31,6 +31,7 @@ struct HttpServerState {
31
31
std::atomic<bool > is_running;
32
32
DatabaseInstance* db_instance;
33
33
unique_ptr<Allocator> allocator;
34
+ std::string auth_token;
34
35
35
36
HttpServerState () : is_running(false ), db_instance(nullptr ) {}
36
37
};
@@ -129,6 +130,51 @@ static std::string ConvertResultToJSON(MaterializedQueryResult &result, ReqStats
129
130
return json_output;
130
131
}
131
132
133
+ // New: Base64 decoding function
134
+ std::string base64_decode (const std::string &in) {
135
+ std::string out;
136
+ std::vector<int > T (256 , -1 );
137
+ for (int i = 0 ; i < 64 ; i++)
138
+ T[" ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" [i]] = i;
139
+
140
+ int val = 0 , valb = -8 ;
141
+ for (unsigned char c : in) {
142
+ if (T[c] == -1 ) break ;
143
+ val = (val << 6 ) + T[c];
144
+ valb += 6 ;
145
+ if (valb >= 0 ) {
146
+ out.push_back (char ((val >> valb) & 0xFF ));
147
+ valb -= 8 ;
148
+ }
149
+ }
150
+ return out;
151
+ }
152
+
153
+ // Auth Check
154
+ bool IsAuthenticated (const duckdb_httplib_openssl::Request& req) {
155
+ if (global_state.auth_token .empty ()) {
156
+ return true ; // No authentication required if no token is set
157
+ }
158
+
159
+ // Check for X-API-Key header
160
+ auto api_key = req.get_header_value (" X-API-Key" );
161
+ if (!api_key.empty () && api_key == global_state.auth_token ) {
162
+ return true ;
163
+ }
164
+
165
+ // Check for Basic Auth
166
+ auto auth = req.get_header_value (" Authorization" );
167
+ if (!auth.empty () && auth.compare (0 , 6 , " Basic " ) == 0 ) {
168
+ std::string decoded_auth = base64_decode (auth.substr (6 ));
169
+ if (decoded_auth == global_state.auth_token ) {
170
+ return true ;
171
+ }
172
+ }
173
+
174
+ return false ;
175
+ }
176
+
177
+
132
178
// Convert the query result to NDJSON (JSONEachRow) format
133
179
static std::string ConvertResultToNDJSON (MaterializedQueryResult &result) {
134
180
std::string ndjson_output;
@@ -208,6 +254,13 @@ static void HandleQuery(const string& query, duckdb_httplib_openssl::Response& r
208
254
void HandleHttpRequest (const duckdb_httplib_openssl::Request& req, duckdb_httplib_openssl::Response& res) {
209
255
std::string query;
210
256
257
+ // Check authentication
258
+ if (!IsAuthenticated (req)) {
259
+ res.status = 401 ;
260
+ res.set_content (" Unauthorized" , " text/plain" );
261
+ return ;
262
+ }
263
+
211
264
// CORS allow
212
265
res.set_header (" Access-Control-Allow-Origin" , " *" );
213
266
res.set_header (" Access-Control-Allow-Methods" , " GET, POST, OPTIONS, PUT" );
@@ -295,14 +348,15 @@ void HandleHttpRequest(const duckdb_httplib_openssl::Request& req, duckdb_httpli
295
348
}
296
349
}
297
350
298
- void HttpServerStart (DatabaseInstance& db, string_t host, int32_t port) {
351
+ void HttpServerStart (DatabaseInstance& db, string_t host, int32_t port, string_t auth = string_t () ) {
299
352
if (global_state.is_running ) {
300
353
throw IOException (" HTTP server is already running" );
301
354
}
302
355
303
356
global_state.db_instance = &db;
304
357
global_state.server = make_uniq<duckdb_httplib_openssl::Server>();
305
358
global_state.is_running = true ;
359
+ global_state.auth_token = auth.GetString ();
306
360
307
361
// CORS Preflight
308
362
global_state.server ->Options (" /" ,
@@ -359,17 +413,19 @@ static void HttpServerCleanup() {
359
413
360
414
static void LoadInternal (DatabaseInstance &instance) {
361
415
auto httpserve_start = ScalarFunction (" httpserve_start" ,
362
- {LogicalType::VARCHAR, LogicalType::INTEGER},
416
+ {LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR },
363
417
LogicalType::VARCHAR,
364
418
[&](DataChunk &args, ExpressionState &state, Vector &result) {
365
419
auto &host_vector = args.data [0 ];
366
420
auto &port_vector = args.data [1 ];
421
+ auto &auth_vector = args.data [2 ];
367
422
368
423
UnaryExecutor::Execute<string_t , string_t >(
369
424
host_vector, result, args.size (),
370
425
[&](string_t host) {
371
426
auto port = ((int32_t *)port_vector.GetData ())[0 ];
372
- HttpServerStart (instance, host, port);
427
+ auto auth = ((string_t *)auth_vector.GetData ())[0 ];
428
+ HttpServerStart (instance, host, port, auth);
373
429
return StringVector::AddString (result, " HTTP server started on " + host.GetString () + " :" + std::to_string (port));
374
430
});
375
431
});
0 commit comments