@@ -299,7 +299,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
299299 model : & str ,
300300 messages : Vec < Message > ,
301301 ) -> Result < GenerateResponse , GatewayError > {
302- let provider_str = provider. to_string ( ) . to_lowercase ( ) ; // force lowercase - TODO - fix the serialization
302+ let provider_str = provider. to_string ( ) ;
303303 let url = format ! ( "{}/llms/{}/generate" , self . base_url, provider_str) ;
304304 let mut request = self . client . post ( & url) ;
305305 if let Some ( token) = & self . token {
@@ -312,23 +312,24 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
312312 } ;
313313
314314 let response = request. json ( & request_payload) . send ( ) . await ?;
315+
315316 match response. status ( ) {
316317 StatusCode :: OK => Ok ( response. json ( ) . await ?) ,
317- StatusCode :: UNAUTHORIZED => {
318- let error: ErrorResponse = response. json ( ) . await ?;
319- Err ( GatewayError :: Unauthorized ( error. error ) )
320- }
321318 StatusCode :: BAD_REQUEST => {
322319 let error: ErrorResponse = response. json ( ) . await ?;
323320 Err ( GatewayError :: BadRequest ( error. error ) )
324321 }
322+ StatusCode :: UNAUTHORIZED => {
323+ let error: ErrorResponse = response. json ( ) . await ?;
324+ Err ( GatewayError :: Unauthorized ( error. error ) )
325+ }
325326 StatusCode :: INTERNAL_SERVER_ERROR => {
326327 let error: ErrorResponse = response. json ( ) . await ?;
327328 Err ( GatewayError :: InternalError ( error. error ) )
328329 }
329- _ => Err ( GatewayError :: Other ( Box :: new ( std:: io:: Error :: new (
330+ status => Err ( GatewayError :: Other ( Box :: new ( std:: io:: Error :: new (
330331 std:: io:: ErrorKind :: Other ,
331- format ! ( "Unexpected status code: {}" , response . status( ) ) ,
332+ format ! ( "Unexpected status code: {}" , status) ,
332333 ) ) ) ) ,
333334 }
334335 }
@@ -349,58 +350,6 @@ mod tests {
349350 use super :: * ;
350351 use mockito:: { Matcher , Server } ;
351352
352- #[ tokio:: test]
353- async fn test_gateway_errors ( ) -> Result < ( ) , GatewayError > {
354- let mut server: mockito:: ServerGuard = Server :: new_async ( ) . await ;
355-
356- // Test unauthorized error
357- let unauthorized_mock = server
358- . mock ( "GET" , "/llms" )
359- . with_status ( 401 )
360- . with_header ( "content-type" , "application/json" )
361- . with_body ( r#"{"error":"Invalid token"}"# )
362- . create ( ) ;
363-
364- let client = InferenceGatewayClient :: new ( & server. url ( ) ) ;
365- match client. list_models ( ) . await {
366- Err ( GatewayError :: Unauthorized ( msg) ) => assert_eq ! ( msg, "Invalid token" ) ,
367- _ => panic ! ( "Expected Unauthorized error" ) ,
368- }
369- unauthorized_mock. assert ( ) ;
370-
371- // Test bad request error
372- let bad_request_mock = server
373- . mock ( "GET" , "/llms" )
374- . with_status ( 400 )
375- . with_header ( "content-type" , "application/json" )
376- . with_body ( r#"{"error":"Invalid provider"}"# )
377- . create ( ) ;
378-
379- match client. list_models ( ) . await {
380- Err ( GatewayError :: BadRequest ( msg) ) => assert_eq ! ( msg, "Invalid provider" ) ,
381- _ => panic ! ( "Expected BadRequest error" ) ,
382- }
383- bad_request_mock. assert ( ) ;
384-
385- // Test internal server error
386- let internal_error_mock = server
387- . mock ( "GET" , "/llms" )
388- . with_status ( 500 )
389- . with_header ( "content-type" , "application/json" )
390- . with_body ( r#"{"error":"Internal server error occurred"}"# )
391- . create ( ) ;
392-
393- match client. list_models ( ) . await {
394- Err ( GatewayError :: InternalError ( msg) ) => {
395- assert_eq ! ( msg, "Internal server error occurred" )
396- }
397- _ => panic ! ( "Expected InternalError error" ) ,
398- }
399- internal_error_mock. assert ( ) ;
400-
401- Ok ( ( ) )
402- }
403-
404353 #[ test]
405354 fn test_provider_serialization ( ) {
406355 let providers = vec ! [
@@ -437,6 +386,18 @@ mod tests {
437386 }
438387 }
439388
389+ #[ test]
390+ fn test_provider_case_serialization ( ) {
391+ // Test that Provider::Groq serializes to lowercase
392+ let provider = Provider :: Groq ;
393+ let json = serde_json:: to_string ( & provider) . unwrap ( ) ;
394+ assert_eq ! ( json, r#""groq""# ) ;
395+
396+ // Test that uppercase fails to deserialize
397+ let result: Result < Provider , _ > = serde_json:: from_str ( r#""Groq""# ) ;
398+ assert ! ( result. is_err( ) ) ;
399+ }
400+
440401 #[ test]
441402 fn test_provider_display ( ) {
442403 let providers = vec ! [
@@ -492,11 +453,16 @@ mod tests {
492453 #[ tokio:: test]
493454 async fn test_unauthorized_error ( ) -> Result < ( ) , GatewayError > {
494455 let mut server = Server :: new_async ( ) . await ;
456+
457+ let raw_json_response = r#"{
458+ "error": "Invalid token"
459+ }"# ;
460+
495461 let mock = server
496462 . mock ( "GET" , "/llms" )
497463 . with_status ( 401 )
498464 . with_header ( "content-type" , "application/json" )
499- . with_body ( r#"{"error":"Invalid token"}"# )
465+ . with_body ( raw_json_response )
500466 . create ( ) ;
501467
502468 let client = InferenceGatewayClient :: new ( & server. url ( ) ) ;
@@ -514,11 +480,21 @@ mod tests {
514480 #[ tokio:: test]
515481 async fn test_list_models ( ) -> Result < ( ) , GatewayError > {
516482 let mut server = Server :: new_async ( ) . await ;
483+
484+ let raw_response_json = r#"[
485+ {
486+ "provider": "ollama",
487+ "models": [
488+ {"name": "llama2"}
489+ ]
490+ }
491+ ]"# ;
492+
517493 let mock = server
518494 . mock ( "GET" , "/llms" )
519495 . with_status ( 200 )
520496 . with_header ( "content-type" , "application/json" )
521- . with_body ( r#"[{"provider":"ollama","models":[{"name":"llama2"}]}]"# )
497+ . with_body ( raw_response_json )
522498 . create ( ) ;
523499
524500 let client = InferenceGatewayClient :: new ( & server. url ( ) ) ;
@@ -534,11 +510,19 @@ mod tests {
534510 #[ tokio:: test]
535511 async fn test_list_models_by_provider ( ) -> Result < ( ) , GatewayError > {
536512 let mut server = Server :: new_async ( ) . await ;
513+
514+ let raw_json_response = r#"{
515+ "provider":"ollama",
516+ "models": [{
517+ "name": "llama2"
518+ }]
519+ }"# ;
520+
537521 let mock = server
538522 . mock ( "GET" , "/llms/ollama" )
539523 . with_status ( 200 )
540524 . with_header ( "content-type" , "application/json" )
541- . with_body ( r#"{"provider":"ollama","models":[{"name":"llama2"}]}"# )
525+ . with_body ( raw_json_response )
542526 . create ( ) ;
543527
544528 let client = InferenceGatewayClient :: new ( & server. url ( ) ) ;
@@ -554,11 +538,21 @@ mod tests {
554538 #[ tokio:: test]
555539 async fn test_generate_content ( ) -> Result < ( ) , GatewayError > {
556540 let mut server = Server :: new_async ( ) . await ;
541+
542+ let raw_json_response = r#"{
543+ "provider":"ollama",
544+ "response":{
545+ "role":"assistant",
546+ "model":"llama2",
547+ "content":"Hellloooo"
548+ }
549+ }"# ;
550+
557551 let mock = server
558552 . mock ( "POST" , "/llms/ollama/generate" )
559553 . with_status ( 200 )
560554 . with_header ( "content-type" , "application/json" )
561- . with_body ( r#"{"provider":"ollama","response":{"role":"assistant","model":"llama2","content":"Hellloooo"}}"# )
555+ . with_body ( raw_json_response )
562556 . create ( ) ;
563557
564558 let client = InferenceGatewayClient :: new ( & server. url ( ) ) ;
@@ -579,6 +573,144 @@ mod tests {
579573 Ok ( ( ) )
580574 }
581575
576+ #[ tokio:: test]
577+ async fn test_generate_content_serialization ( ) -> Result < ( ) , GatewayError > {
578+ let mut server = Server :: new_async ( ) . await ;
579+
580+ // Raw JSON response from API for debugging
581+ let raw_json = r#"{
582+ "provider": "groq",
583+ "response": {
584+ "role": "assistant",
585+ "model": "mixtral-8x7b",
586+ "content": "Hello"
587+ }
588+ }"# ;
589+
590+ // Create mock with exact JSON structure
591+ let mock = server
592+ . mock ( "POST" , "/llms/groq/generate" )
593+ . with_status ( 200 )
594+ . with_header ( "content-type" , "application/json" )
595+ . with_body ( raw_json)
596+ . create ( ) ;
597+
598+ let client = InferenceGatewayClient :: new ( & server. url ( ) ) ;
599+
600+ // Test direct JSON deserialization first
601+ let direct_parse: Result < GenerateResponse , _ > = serde_json:: from_str ( raw_json) ;
602+ assert ! (
603+ direct_parse. is_ok( ) ,
604+ "Direct JSON parse failed: {:?}" ,
605+ direct_parse. err( )
606+ ) ;
607+
608+ // Test through client
609+ let messages = vec ! [ Message {
610+ role: MessageRole :: User ,
611+ content: "Hello" . to_string( ) ,
612+ } ] ;
613+
614+ let response = client
615+ . generate_content ( Provider :: Groq , "mixtral-8x7b" , messages)
616+ . await ?;
617+
618+ // Verify structure matches
619+ assert_eq ! ( response. provider, Provider :: Groq ) ;
620+ assert_eq ! ( response. response. role, MessageRole :: Assistant ) ;
621+ assert_eq ! ( response. response. model, "mixtral-8x7b" ) ;
622+ assert_eq ! ( response. response. content, "Hello" ) ;
623+
624+ mock. assert ( ) ;
625+ Ok ( ( ) )
626+ }
627+
628+ #[ tokio:: test]
629+ async fn test_generate_content_error_response ( ) -> Result < ( ) , GatewayError > {
630+ let mut server = Server :: new_async ( ) . await ;
631+
632+ let raw_json_response = r#"{
633+ "error":"Invalid request"
634+ }"# ;
635+
636+ let mock = server
637+ . mock ( "POST" , "/llms/groq/generate" )
638+ . with_status ( 400 )
639+ . with_header ( "content-type" , "application/json" )
640+ . with_body ( raw_json_response)
641+ . create ( ) ;
642+
643+ let client = InferenceGatewayClient :: new ( & server. url ( ) ) ;
644+ let messages = vec ! [ Message {
645+ role: MessageRole :: User ,
646+ content: "Hello" . to_string( ) ,
647+ } ] ;
648+ let error = client
649+ . generate_content ( Provider :: Groq , "mixtral-8x7b" , messages)
650+ . await
651+ . unwrap_err ( ) ;
652+
653+ assert ! ( matches!( error, GatewayError :: BadRequest ( _) ) ) ;
654+ if let GatewayError :: BadRequest ( msg) = error {
655+ assert_eq ! ( msg, "Invalid request" ) ;
656+ }
657+ mock. assert ( ) ;
658+
659+ Ok ( ( ) )
660+ }
661+
662+ #[ tokio:: test]
663+ async fn test_gateway_errors ( ) -> Result < ( ) , GatewayError > {
664+ let mut server: mockito:: ServerGuard = Server :: new_async ( ) . await ;
665+
666+ // Test unauthorized error
667+ let unauthorized_mock = server
668+ . mock ( "GET" , "/llms" )
669+ . with_status ( 401 )
670+ . with_header ( "content-type" , "application/json" )
671+ . with_body ( r#"{"error":"Invalid token"}"# )
672+ . create ( ) ;
673+
674+ let client = InferenceGatewayClient :: new ( & server. url ( ) ) ;
675+ match client. list_models ( ) . await {
676+ Err ( GatewayError :: Unauthorized ( msg) ) => assert_eq ! ( msg, "Invalid token" ) ,
677+ _ => panic ! ( "Expected Unauthorized error" ) ,
678+ }
679+ unauthorized_mock. assert ( ) ;
680+
681+ // Test bad request error
682+ let bad_request_mock = server
683+ . mock ( "GET" , "/llms" )
684+ . with_status ( 400 )
685+ . with_header ( "content-type" , "application/json" )
686+ . with_body ( r#"{"error":"Invalid provider"}"# )
687+ . create ( ) ;
688+
689+ match client. list_models ( ) . await {
690+ Err ( GatewayError :: BadRequest ( msg) ) => assert_eq ! ( msg, "Invalid provider" ) ,
691+ _ => panic ! ( "Expected BadRequest error" ) ,
692+ }
693+ bad_request_mock. assert ( ) ;
694+
695+ // Test internal server error
696+ let internal_error_mock = server
697+ . mock ( "GET" , "/llms" )
698+ . with_status ( 500 )
699+ . with_header ( "content-type" , "application/json" )
700+ . with_body ( r#"{"error":"Internal server error occurred"}"# )
701+ . create ( ) ;
702+
703+ match client. list_models ( ) . await {
704+ Err ( GatewayError :: InternalError ( msg) ) => {
705+ assert_eq ! ( msg, "Internal server error occurred" )
706+ }
707+ _ => panic ! ( "Expected InternalError error" ) ,
708+ }
709+ internal_error_mock. assert ( ) ;
710+
711+ Ok ( ( ) )
712+ }
713+
582714 #[ tokio:: test]
583715 async fn test_health_check ( ) -> Result < ( ) , GatewayError > {
584716 let mut server = Server :: new_async ( ) . await ;
0 commit comments