Skip to content

Commit 0715d11

Browse files
committed
chore: Add case-insensitive provider aliases and update URL generation
Signed-off-by: Eden Reich <eden.reich@gmail.com>
1 parent e9231d2 commit 0715d11

File tree

1 file changed

+45
-14
lines changed

1 file changed

+45
-14
lines changed

src/lib.rs

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,19 @@ pub struct ProviderModels {
5555
#[derive(Debug, Serialize, Deserialize, PartialEq)]
5656
#[serde(rename_all = "lowercase")]
5757
pub enum Provider {
58+
#[serde(alias = "Ollama", alias = "OLLAMA")]
5859
Ollama,
60+
#[serde(alias = "Groq", alias = "GROQ")]
5961
Groq,
62+
#[serde(alias = "OpenAI", alias = "OPENAI")]
6063
OpenAI,
64+
#[serde(alias = "Google", alias = "GOOGLE")]
6165
Google,
66+
#[serde(alias = "Cloudflare", alias = "CLOUDFLARE")]
6267
Cloudflare,
68+
#[serde(alias = "Cohere", alias = "COHERE")]
6369
Cohere,
70+
#[serde(alias = "Anthropic", alias = "ANTHROPIC")]
6471
Anthropic,
6572
}
6673

@@ -299,8 +306,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
299306
model: &str,
300307
messages: Vec<Message>,
301308
) -> Result<GenerateResponse, GatewayError> {
302-
let provider_str = provider.to_string();
303-
let url = format!("{}/llms/{}/generate", self.base_url, provider_str);
309+
let url = format!("{}/llms/{}/generate", self.base_url, provider);
304310
let mut request = self.client.post(&url);
305311
if let Some(token) = &self.token {
306312
request = request.bearer_auth(token);
@@ -386,18 +392,6 @@ mod tests {
386392
}
387393
}
388394

389-
#[test]
390-
fn test_provider_case_serialization() {
391-
// Test that Provider::Groq serializes to lowercase
392-
let provider = Provider::Groq;
393-
let json = serde_json::to_string(&provider).unwrap();
394-
assert_eq!(json, r#""groq""#);
395-
396-
// Test that uppercase fails to deserialize
397-
let result: Result<Provider, _> = serde_json::from_str(r#""Groq""#);
398-
assert!(result.is_err());
399-
}
400-
401395
#[test]
402396
fn test_provider_display() {
403397
let providers = vec![
@@ -711,6 +705,43 @@ mod tests {
711705
Ok(())
712706
}
713707

708+
#[tokio::test]
709+
async fn test_generate_content_case_insensitive() -> Result<(), GatewayError> {
710+
let mut server = Server::new_async().await;
711+
712+
let raw_json = r#"{
713+
"provider": "Groq",
714+
"response": {
715+
"role": "assistant",
716+
"model": "mixtral-8x7b",
717+
"content": "Hello"
718+
}
719+
}"#;
720+
721+
let mock = server
722+
.mock("POST", "/llms/groq/generate")
723+
.with_status(200)
724+
.with_header("content-type", "application/json")
725+
.with_body(raw_json)
726+
.create();
727+
728+
let client = InferenceGatewayClient::new(&server.url());
729+
let messages = vec![Message {
730+
role: MessageRole::User,
731+
content: "Hello".to_string(),
732+
}];
733+
734+
let response = client
735+
.generate_content(Provider::Groq, "mixtral-8x7b", messages)
736+
.await?;
737+
738+
assert_eq!(response.provider, Provider::Groq);
739+
assert_eq!(response.response.content, "Hello");
740+
mock.assert();
741+
742+
Ok(())
743+
}
744+
714745
#[tokio::test]
715746
async fn test_health_check() -> Result<(), GatewayError> {
716747
let mut server = Server::new_async().await;

0 commit comments

Comments
 (0)