Skip to content

Commit cdeb27c

Browse files
committed
test: Remove unnecessary lowercase conversion for provider string and enhance error handling in API responses
Signed-off-by: Eden Reich <eden.reich@gmail.com>
1 parent 41e9d94 commit cdeb27c

File tree

1 file changed

+195
-63
lines changed

1 file changed

+195
-63
lines changed

src/lib.rs

Lines changed: 195 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
299299
model: &str,
300300
messages: Vec<Message>,
301301
) -> Result<GenerateResponse, GatewayError> {
302-
let provider_str = provider.to_string().to_lowercase(); // force lowercase - TODO - fix the serialization
302+
let provider_str = provider.to_string();
303303
let url = format!("{}/llms/{}/generate", self.base_url, provider_str);
304304
let mut request = self.client.post(&url);
305305
if let Some(token) = &self.token {
@@ -312,23 +312,24 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
312312
};
313313

314314
let response = request.json(&request_payload).send().await?;
315+
315316
match response.status() {
316317
StatusCode::OK => Ok(response.json().await?),
317-
StatusCode::UNAUTHORIZED => {
318-
let error: ErrorResponse = response.json().await?;
319-
Err(GatewayError::Unauthorized(error.error))
320-
}
321318
StatusCode::BAD_REQUEST => {
322319
let error: ErrorResponse = response.json().await?;
323320
Err(GatewayError::BadRequest(error.error))
324321
}
322+
StatusCode::UNAUTHORIZED => {
323+
let error: ErrorResponse = response.json().await?;
324+
Err(GatewayError::Unauthorized(error.error))
325+
}
325326
StatusCode::INTERNAL_SERVER_ERROR => {
326327
let error: ErrorResponse = response.json().await?;
327328
Err(GatewayError::InternalError(error.error))
328329
}
329-
_ => Err(GatewayError::Other(Box::new(std::io::Error::new(
330+
status => Err(GatewayError::Other(Box::new(std::io::Error::new(
330331
std::io::ErrorKind::Other,
331-
format!("Unexpected status code: {}", response.status()),
332+
format!("Unexpected status code: {}", status),
332333
)))),
333334
}
334335
}
@@ -349,58 +350,6 @@ mod tests {
349350
use super::*;
350351
use mockito::{Matcher, Server};
351352

352-
#[tokio::test]
353-
async fn test_gateway_errors() -> Result<(), GatewayError> {
354-
let mut server: mockito::ServerGuard = Server::new_async().await;
355-
356-
// Test unauthorized error
357-
let unauthorized_mock = server
358-
.mock("GET", "/llms")
359-
.with_status(401)
360-
.with_header("content-type", "application/json")
361-
.with_body(r#"{"error":"Invalid token"}"#)
362-
.create();
363-
364-
let client = InferenceGatewayClient::new(&server.url());
365-
match client.list_models().await {
366-
Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
367-
_ => panic!("Expected Unauthorized error"),
368-
}
369-
unauthorized_mock.assert();
370-
371-
// Test bad request error
372-
let bad_request_mock = server
373-
.mock("GET", "/llms")
374-
.with_status(400)
375-
.with_header("content-type", "application/json")
376-
.with_body(r#"{"error":"Invalid provider"}"#)
377-
.create();
378-
379-
match client.list_models().await {
380-
Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
381-
_ => panic!("Expected BadRequest error"),
382-
}
383-
bad_request_mock.assert();
384-
385-
// Test internal server error
386-
let internal_error_mock = server
387-
.mock("GET", "/llms")
388-
.with_status(500)
389-
.with_header("content-type", "application/json")
390-
.with_body(r#"{"error":"Internal server error occurred"}"#)
391-
.create();
392-
393-
match client.list_models().await {
394-
Err(GatewayError::InternalError(msg)) => {
395-
assert_eq!(msg, "Internal server error occurred")
396-
}
397-
_ => panic!("Expected InternalError error"),
398-
}
399-
internal_error_mock.assert();
400-
401-
Ok(())
402-
}
403-
404353
#[test]
405354
fn test_provider_serialization() {
406355
let providers = vec![
@@ -437,6 +386,18 @@ mod tests {
437386
}
438387
}
439388

389+
#[test]
390+
fn test_provider_case_serialization() {
391+
// Test that Provider::Groq serializes to lowercase
392+
let provider = Provider::Groq;
393+
let json = serde_json::to_string(&provider).unwrap();
394+
assert_eq!(json, r#""groq""#);
395+
396+
// Test that uppercase fails to deserialize
397+
let result: Result<Provider, _> = serde_json::from_str(r#""Groq""#);
398+
assert!(result.is_err());
399+
}
400+
440401
#[test]
441402
fn test_provider_display() {
442403
let providers = vec![
@@ -492,11 +453,16 @@ mod tests {
492453
#[tokio::test]
493454
async fn test_unauthorized_error() -> Result<(), GatewayError> {
494455
let mut server = Server::new_async().await;
456+
457+
let raw_json_response = r#"{
458+
"error": "Invalid token"
459+
}"#;
460+
495461
let mock = server
496462
.mock("GET", "/llms")
497463
.with_status(401)
498464
.with_header("content-type", "application/json")
499-
.with_body(r#"{"error":"Invalid token"}"#)
465+
.with_body(raw_json_response)
500466
.create();
501467

502468
let client = InferenceGatewayClient::new(&server.url());
@@ -514,11 +480,21 @@ mod tests {
514480
#[tokio::test]
515481
async fn test_list_models() -> Result<(), GatewayError> {
516482
let mut server = Server::new_async().await;
483+
484+
let raw_response_json = r#"[
485+
{
486+
"provider": "ollama",
487+
"models": [
488+
{"name": "llama2"}
489+
]
490+
}
491+
]"#;
492+
517493
let mock = server
518494
.mock("GET", "/llms")
519495
.with_status(200)
520496
.with_header("content-type", "application/json")
521-
.with_body(r#"[{"provider":"ollama","models":[{"name":"llama2"}]}]"#)
497+
.with_body(raw_response_json)
522498
.create();
523499

524500
let client = InferenceGatewayClient::new(&server.url());
@@ -534,11 +510,19 @@ mod tests {
534510
#[tokio::test]
535511
async fn test_list_models_by_provider() -> Result<(), GatewayError> {
536512
let mut server = Server::new_async().await;
513+
514+
let raw_json_response = r#"{
515+
"provider":"ollama",
516+
"models": [{
517+
"name": "llama2"
518+
}]
519+
}"#;
520+
537521
let mock = server
538522
.mock("GET", "/llms/ollama")
539523
.with_status(200)
540524
.with_header("content-type", "application/json")
541-
.with_body(r#"{"provider":"ollama","models":[{"name":"llama2"}]}"#)
525+
.with_body(raw_json_response)
542526
.create();
543527

544528
let client = InferenceGatewayClient::new(&server.url());
@@ -554,11 +538,21 @@ mod tests {
554538
#[tokio::test]
555539
async fn test_generate_content() -> Result<(), GatewayError> {
556540
let mut server = Server::new_async().await;
541+
542+
let raw_json_response = r#"{
543+
"provider":"ollama",
544+
"response":{
545+
"role":"assistant",
546+
"model":"llama2",
547+
"content":"Hellloooo"
548+
}
549+
}"#;
550+
557551
let mock = server
558552
.mock("POST", "/llms/ollama/generate")
559553
.with_status(200)
560554
.with_header("content-type", "application/json")
561-
.with_body(r#"{"provider":"ollama","response":{"role":"assistant","model":"llama2","content":"Hellloooo"}}"#)
555+
.with_body(raw_json_response)
562556
.create();
563557

564558
let client = InferenceGatewayClient::new(&server.url());
@@ -579,6 +573,144 @@ mod tests {
579573
Ok(())
580574
}
581575

576+
#[tokio::test]
577+
async fn test_generate_content_serialization() -> Result<(), GatewayError> {
578+
let mut server = Server::new_async().await;
579+
580+
// Raw JSON response from API for debugging
581+
let raw_json = r#"{
582+
"provider": "groq",
583+
"response": {
584+
"role": "assistant",
585+
"model": "mixtral-8x7b",
586+
"content": "Hello"
587+
}
588+
}"#;
589+
590+
// Create mock with exact JSON structure
591+
let mock = server
592+
.mock("POST", "/llms/groq/generate")
593+
.with_status(200)
594+
.with_header("content-type", "application/json")
595+
.with_body(raw_json)
596+
.create();
597+
598+
let client = InferenceGatewayClient::new(&server.url());
599+
600+
// Test direct JSON deserialization first
601+
let direct_parse: Result<GenerateResponse, _> = serde_json::from_str(raw_json);
602+
assert!(
603+
direct_parse.is_ok(),
604+
"Direct JSON parse failed: {:?}",
605+
direct_parse.err()
606+
);
607+
608+
// Test through client
609+
let messages = vec![Message {
610+
role: MessageRole::User,
611+
content: "Hello".to_string(),
612+
}];
613+
614+
let response = client
615+
.generate_content(Provider::Groq, "mixtral-8x7b", messages)
616+
.await?;
617+
618+
// Verify structure matches
619+
assert_eq!(response.provider, Provider::Groq);
620+
assert_eq!(response.response.role, MessageRole::Assistant);
621+
assert_eq!(response.response.model, "mixtral-8x7b");
622+
assert_eq!(response.response.content, "Hello");
623+
624+
mock.assert();
625+
Ok(())
626+
}
627+
628+
#[tokio::test]
629+
async fn test_generate_content_error_response() -> Result<(), GatewayError> {
630+
let mut server = Server::new_async().await;
631+
632+
let raw_json_response = r#"{
633+
"error":"Invalid request"
634+
}"#;
635+
636+
let mock = server
637+
.mock("POST", "/llms/groq/generate")
638+
.with_status(400)
639+
.with_header("content-type", "application/json")
640+
.with_body(raw_json_response)
641+
.create();
642+
643+
let client = InferenceGatewayClient::new(&server.url());
644+
let messages = vec![Message {
645+
role: MessageRole::User,
646+
content: "Hello".to_string(),
647+
}];
648+
let error = client
649+
.generate_content(Provider::Groq, "mixtral-8x7b", messages)
650+
.await
651+
.unwrap_err();
652+
653+
assert!(matches!(error, GatewayError::BadRequest(_)));
654+
if let GatewayError::BadRequest(msg) = error {
655+
assert_eq!(msg, "Invalid request");
656+
}
657+
mock.assert();
658+
659+
Ok(())
660+
}
661+
662+
#[tokio::test]
663+
async fn test_gateway_errors() -> Result<(), GatewayError> {
664+
let mut server: mockito::ServerGuard = Server::new_async().await;
665+
666+
// Test unauthorized error
667+
let unauthorized_mock = server
668+
.mock("GET", "/llms")
669+
.with_status(401)
670+
.with_header("content-type", "application/json")
671+
.with_body(r#"{"error":"Invalid token"}"#)
672+
.create();
673+
674+
let client = InferenceGatewayClient::new(&server.url());
675+
match client.list_models().await {
676+
Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
677+
_ => panic!("Expected Unauthorized error"),
678+
}
679+
unauthorized_mock.assert();
680+
681+
// Test bad request error
682+
let bad_request_mock = server
683+
.mock("GET", "/llms")
684+
.with_status(400)
685+
.with_header("content-type", "application/json")
686+
.with_body(r#"{"error":"Invalid provider"}"#)
687+
.create();
688+
689+
match client.list_models().await {
690+
Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
691+
_ => panic!("Expected BadRequest error"),
692+
}
693+
bad_request_mock.assert();
694+
695+
// Test internal server error
696+
let internal_error_mock = server
697+
.mock("GET", "/llms")
698+
.with_status(500)
699+
.with_header("content-type", "application/json")
700+
.with_body(r#"{"error":"Internal server error occurred"}"#)
701+
.create();
702+
703+
match client.list_models().await {
704+
Err(GatewayError::InternalError(msg)) => {
705+
assert_eq!(msg, "Internal server error occurred")
706+
}
707+
_ => panic!("Expected InternalError error"),
708+
}
709+
internal_error_mock.assert();
710+
711+
Ok(())
712+
}
713+
582714
#[tokio::test]
583715
async fn test_health_check() -> Result<(), GatewayError> {
584716
let mut server = Server::new_async().await;

0 commit comments

Comments
 (0)