Skip to content

Commit fe85dcb

Browse files
committed
chore: Update error handling to use thiserror crate and improve GatewayError structure
Signed-off-by: Eden Reich <eden.reich@gmail.com>
1 parent 0df8c87 commit fe85dcb

File tree

3 files changed

+93
-40
lines changed

3 files changed

+93
-40
lines changed

Cargo.lock

Lines changed: 23 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ categories = ["api-bindings", "web-programming::http-client"]
1414
[dependencies]
1515
reqwest = { version = "0.12.12", features = ["blocking", "json"] }
1616
serde = { version = "1.0.217", features = ["derive"] }
17-
serde_json = "1.0.137"
17+
serde_json = "1.0.138"
18+
thiserror = "2.0.11"
1819

1920
[dev-dependencies]
2021
mockito = "1.6.1"

src/lib.rs

Lines changed: 68 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,51 +3,29 @@
33
//! This crate provides a Rust client for the Inference Gateway API, allowing interaction
44
//! with various LLM providers through a unified interface.
55
6+
use core::fmt;
7+
68
use reqwest::{blocking::Client, StatusCode};
79
use serde::{Deserialize, Serialize};
8-
use std::{error::Error, fmt};
10+
use thiserror::Error;
911

1012
/// Custom error types for the Inference Gateway SDK
11-
#[derive(Debug)]
13+
#[derive(Error, Debug)]
1214
pub enum GatewayError {
13-
/// Authentication error (401)
15+
#[error("Unauthorized: {0}")]
1416
Unauthorized(String),
15-
/// Bad request error (400)
17+
18+
#[error("Bad request: {0}")]
1619
BadRequest(String),
17-
/// Internal server error (500)
18-
InternalError(String),
19-
/// Network or reqwest-related error
20-
RequestError(reqwest::Error),
21-
/// Other unexpected errors
22-
Other(Box<dyn Error + Send + Sync>),
23-
}
2420

25-
impl fmt::Display for GatewayError {
26-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27-
match self {
28-
Self::Unauthorized(msg) => write!(f, "Unauthorized: {}", msg),
29-
Self::BadRequest(msg) => write!(f, "Bad request: {}", msg),
30-
Self::InternalError(msg) => write!(f, "Internal server error: {}", msg),
31-
Self::RequestError(e) => write!(f, "Request error: {}", e),
32-
Self::Other(e) => write!(f, "Other error: {}", e),
33-
}
34-
}
35-
}
21+
#[error("Internal server error: {0}")]
22+
InternalError(String),
3623

37-
impl Error for GatewayError {
38-
fn source(&self) -> Option<&(dyn Error + 'static)> {
39-
match self {
40-
Self::RequestError(e) => Some(e),
41-
Self::Other(e) => Some(e.as_ref()),
42-
_ => None,
43-
}
44-
}
45-
}
24+
#[error("Request error: {0}")]
25+
RequestError(#[from] reqwest::Error),
4626

47-
impl From<reqwest::Error> for GatewayError {
48-
fn from(err: reqwest::Error) -> Self {
49-
Self::RequestError(err)
50-
}
27+
#[error("Other error: {0}")]
28+
Other(#[from] Box<dyn std::error::Error + Send + Sync>),
5129
}
5230

5331
#[derive(Debug, Deserialize)]
@@ -181,7 +159,7 @@ pub trait InferenceGatewayAPI {
181159
) -> Result<GenerateResponse, GatewayError>;
182160

183161
/// Checks if the API is available
184-
fn health_check(&self) -> Result<bool, Box<dyn Error>>;
162+
fn health_check(&self) -> Result<bool, GatewayError>;
185163
}
186164

187165
impl InferenceGatewayClient {
@@ -308,7 +286,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
308286
}
309287
}
310288

311-
fn health_check(&self) -> Result<bool, Box<dyn Error>> {
289+
fn health_check(&self) -> Result<bool, GatewayError> {
312290
let url = format!("{}/health", self.base_url);
313291
let response = self.client.get(&url).send()?;
314292
Ok(response.status().is_success())
@@ -320,6 +298,59 @@ mod tests {
320298
use super::*;
321299
use mockito::{Matcher, Server};
322300

301+
#[test]
302+
fn test_provider_serialization() {
303+
let providers = vec![
304+
(Provider::Ollama, "ollama"),
305+
(Provider::Groq, "groq"),
306+
(Provider::OpenAI, "openai"),
307+
(Provider::Google, "google"),
308+
(Provider::Cloudflare, "cloudflare"),
309+
(Provider::Cohere, "cohere"),
310+
(Provider::Anthropic, "anthropic"),
311+
];
312+
313+
for (provider, expected) in providers {
314+
let json = serde_json::to_string(&provider).unwrap();
315+
assert_eq!(json, format!("\"{}\"", expected));
316+
}
317+
}
318+
319+
#[test]
320+
fn test_provider_deserialization() {
321+
let test_cases = vec![
322+
("\"ollama\"", Provider::Ollama),
323+
("\"groq\"", Provider::Groq),
324+
("\"openai\"", Provider::OpenAI),
325+
("\"google\"", Provider::Google),
326+
("\"cloudflare\"", Provider::Cloudflare),
327+
("\"cohere\"", Provider::Cohere),
328+
("\"anthropic\"", Provider::Anthropic),
329+
];
330+
331+
for (json, expected) in test_cases {
332+
let provider: Provider = serde_json::from_str(json).unwrap();
333+
assert_eq!(provider, expected);
334+
}
335+
}
336+
337+
#[test]
338+
fn test_provider_display() {
339+
let providers = vec![
340+
(Provider::Ollama, "ollama"),
341+
(Provider::Groq, "groq"),
342+
(Provider::OpenAI, "openai"),
343+
(Provider::Google, "google"),
344+
(Provider::Cloudflare, "cloudflare"),
345+
(Provider::Cohere, "cohere"),
346+
(Provider::Anthropic, "anthropic"),
347+
];
348+
349+
for (provider, expected) in providers {
350+
assert_eq!(provider.to_string(), expected);
351+
}
352+
}
353+
323354
#[test]
324355
fn test_authentication_header() {
325356
let mut server = Server::new();

0 commit comments

Comments
 (0)