Skip to content

Commit 3da7e50

Browse files
committed
chore: Implement TryFrom for Provider just in case users want to convert the string to an enum
Signed-off-by: Eden Reich <eden.reich@gmail.com>
1 parent 9c8bbc3 commit 3da7e50

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

src/lib.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,23 @@ impl fmt::Display for Provider {
7676
}
7777
}
7878

79+
impl TryFrom<&str> for Provider {
80+
type Error = GatewayError;
81+
82+
fn try_from(s: &str) -> Result<Self, Self::Error> {
83+
match s.to_lowercase().as_str() {
84+
"ollama" => Ok(Self::Ollama),
85+
"groq" => Ok(Self::Groq),
86+
"openai" => Ok(Self::OpenAI),
87+
"google" => Ok(Self::Google),
88+
"cloudflare" => Ok(Self::Cloudflare),
89+
"cohere" => Ok(Self::Cohere),
90+
"anthropic" => Ok(Self::Anthropic),
91+
_ => Err(GatewayError::BadRequest(format!("Unknown provider: {}", s))),
92+
}
93+
}
94+
}
95+
7996
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
8097
#[serde(rename_all = "lowercase")]
8198
pub enum MessageRole {
@@ -137,12 +154,35 @@ pub struct InferenceGatewayClient {
137154
token: Option<String>,
138155
}
139156

157+
/// Implement Debug for InferenceGatewayClient
158+
impl std::fmt::Debug for InferenceGatewayClient {
159+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160+
f.debug_struct("InferenceGatewayClient")
161+
.field("base_url", &self.base_url)
162+
.field("token", &self.token.as_ref().map(|_| "*****"))
163+
.finish()
164+
}
165+
}
166+
140167
/// Core API interface for the Inference Gateway
141168
pub trait InferenceGatewayAPI {
142169
/// Lists available models from all providers
170+
///
171+
/// # Errors
172+
/// - Returns [`GatewayError::Unauthorized`] if authentication fails
173+
/// - Returns [`GatewayError::BadRequest`] if the request is malformed
174+
/// - Returns [`GatewayError::InternalError`] if the server has an error
143175
fn list_models(&self) -> Result<Vec<ProviderModels>, GatewayError>;
144176

145177
/// Lists available models by a specific provider
178+
///
179+
/// # Arguments
180+
/// * `provider` - The LLM provider to list models for
181+
///
182+
/// # Errors
183+
/// - Returns [`GatewayError::Unauthorized`] if authentication fails
184+
/// - Returns [`GatewayError::BadRequest`] if the request is malformed
185+
/// - Returns [`GatewayError::InternalError`] if the server has an error
146186
fn list_models_by_provider(&self, provider: Provider) -> Result<ProviderModels, GatewayError>;
147187

148188
/// Generates content using a specified model

0 commit comments

Comments
 (0)