@@ -55,12 +55,19 @@ pub struct ProviderModels {
5555#[ derive( Debug , Serialize , Deserialize , PartialEq ) ]
5656#[ serde( rename_all = "lowercase" ) ]
5757pub enum Provider {
58+ #[ serde( alias = "Ollama" , alias = "OLLAMA" ) ]
5859 Ollama ,
60+ #[ serde( alias = "Groq" , alias = "GROQ" ) ]
5961 Groq ,
62+ #[ serde( alias = "OpenAI" , alias = "OPENAI" ) ]
6063 OpenAI ,
64+ #[ serde( alias = "Google" , alias = "GOOGLE" ) ]
6165 Google ,
66+ #[ serde( alias = "Cloudflare" , alias = "CLOUDFLARE" ) ]
6267 Cloudflare ,
68+ #[ serde( alias = "Cohere" , alias = "COHERE" ) ]
6369 Cohere ,
70+ #[ serde( alias = "Anthropic" , alias = "ANTHROPIC" ) ]
6471 Anthropic ,
6572}
6673
@@ -299,8 +306,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
299306 model : & str ,
300307 messages : Vec < Message > ,
301308 ) -> Result < GenerateResponse , GatewayError > {
302- let provider_str = provider. to_string ( ) ;
303- let url = format ! ( "{}/llms/{}/generate" , self . base_url, provider_str) ;
309+ let url = format ! ( "{}/llms/{}/generate" , self . base_url, provider) ;
304310 let mut request = self . client . post ( & url) ;
305311 if let Some ( token) = & self . token {
306312 request = request. bearer_auth ( token) ;
@@ -386,18 +392,6 @@ mod tests {
386392 }
387393 }
388394
389- #[ test]
390- fn test_provider_case_serialization ( ) {
391- // Test that Provider::Groq serializes to lowercase
392- let provider = Provider :: Groq ;
393- let json = serde_json:: to_string ( & provider) . unwrap ( ) ;
394- assert_eq ! ( json, r#""groq""# ) ;
395-
396- // Test that uppercase fails to deserialize
397- let result: Result < Provider , _ > = serde_json:: from_str ( r#""Groq""# ) ;
398- assert ! ( result. is_err( ) ) ;
399- }
400-
401395 #[ test]
402396 fn test_provider_display ( ) {
403397 let providers = vec ! [
@@ -711,6 +705,43 @@ mod tests {
711705 Ok ( ( ) )
712706 }
713707
708+ #[ tokio:: test]
709+ async fn test_generate_content_case_insensitive ( ) -> Result < ( ) , GatewayError > {
710+ let mut server = Server :: new_async ( ) . await ;
711+
712+ let raw_json = r#"{
713+ "provider": "Groq",
714+ "response": {
715+ "role": "assistant",
716+ "model": "mixtral-8x7b",
717+ "content": "Hello"
718+ }
719+ }"# ;
720+
721+ let mock = server
722+ . mock ( "POST" , "/llms/groq/generate" )
723+ . with_status ( 200 )
724+ . with_header ( "content-type" , "application/json" )
725+ . with_body ( raw_json)
726+ . create ( ) ;
727+
728+ let client = InferenceGatewayClient :: new ( & server. url ( ) ) ;
729+ let messages = vec ! [ Message {
730+ role: MessageRole :: User ,
731+ content: "Hello" . to_string( ) ,
732+ } ] ;
733+
734+ let response = client
735+ . generate_content ( Provider :: Groq , "mixtral-8x7b" , messages)
736+ . await ?;
737+
738+ assert_eq ! ( response. provider, Provider :: Groq ) ;
739+ assert_eq ! ( response. response. content, "Hello" ) ;
740+ mock. assert ( ) ;
741+
742+ Ok ( ( ) )
743+ }
744+
714745 #[ tokio:: test]
715746 async fn test_health_check ( ) -> Result < ( ) , GatewayError > {
716747 let mut server = Server :: new_async ( ) . await ;
0 commit comments