Skip to content

Commit 6b7e6d4

Browse files
committed
test: Add unit tests for GatewayError handling in InferenceGatewayClient
Signed-off-by: Eden Reich <eden.reich@gmail.com>
1 parent fe85dcb commit 6b7e6d4

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

src/lib.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,56 @@ mod tests {
298298
use super::*;
299299
use mockito::{Matcher, Server};
300300

301+
#[test]
302+
fn test_gateway_errors() {
303+
let mut server = Server::new();
304+
305+
// Test unauthorized error
306+
let unauthorized_mock = server
307+
.mock("GET", "/llms")
308+
.with_status(401)
309+
.with_header("content-type", "application/json")
310+
.with_body(r#"{"error":"Invalid token"}"#)
311+
.create();
312+
313+
let client = InferenceGatewayClient::new(&server.url());
314+
match client.list_models() {
315+
Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
316+
_ => panic!("Expected Unauthorized error"),
317+
}
318+
unauthorized_mock.assert();
319+
320+
// Test bad request error
321+
let bad_request_mock = server
322+
.mock("GET", "/llms")
323+
.with_status(400)
324+
.with_header("content-type", "application/json")
325+
.with_body(r#"{"error":"Invalid provider"}"#)
326+
.create();
327+
328+
match client.list_models() {
329+
Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
330+
_ => panic!("Expected BadRequest error"),
331+
}
332+
bad_request_mock.assert();
333+
334+
// Test internal server error
335+
let internal_error_mock = server
336+
.mock("GET", "/llms")
337+
.with_status(500)
338+
.with_header("content-type", "application/json")
339+
.with_body(r#"{"error":"Internal server error occurred"}"#)
340+
.create();
341+
342+
match client.list_models() {
343+
Err(GatewayError::InternalError(msg)) => {
344+
assert_eq!(msg, "Internal server error occurred")
345+
}
346+
_ => panic!("Expected InternalError error"),
347+
}
348+
internal_error_mock.assert();
349+
}
350+
301351
#[test]
302352
fn test_provider_serialization() {
303353
let providers = vec![

0 commit comments

Comments
 (0)