33//! This crate provides a Rust client for the Inference Gateway API, allowing interaction
44//! with various LLM providers through a unified interface.
55
6+ use core:: fmt;
7+
68use reqwest:: { blocking:: Client , StatusCode } ;
79use serde:: { Deserialize , Serialize } ;
8- use std :: { error :: Error , fmt } ;
10+ use thiserror :: Error ;
911
1012/// Custom error types for the Inference Gateway SDK
11- #[ derive( Debug ) ]
13+ #[ derive( Error , Debug ) ]
1214pub enum GatewayError {
13- /// Authentication error (401)
15+ # [ error( "Unauthorized: {0}" ) ]
1416 Unauthorized ( String ) ,
15- /// Bad request error (400)
17+
18+ #[ error( "Bad request: {0}" ) ]
1619 BadRequest ( String ) ,
17- /// Internal server error (500)
18- InternalError ( String ) ,
19- /// Network or reqwest-related error
20- RequestError ( reqwest:: Error ) ,
21- /// Other unexpected errors
22- Other ( Box < dyn Error + Send + Sync > ) ,
23- }
2420
25- impl fmt:: Display for GatewayError {
26- fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
27- match self {
28- Self :: Unauthorized ( msg) => write ! ( f, "Unauthorized: {}" , msg) ,
29- Self :: BadRequest ( msg) => write ! ( f, "Bad request: {}" , msg) ,
30- Self :: InternalError ( msg) => write ! ( f, "Internal server error: {}" , msg) ,
31- Self :: RequestError ( e) => write ! ( f, "Request error: {}" , e) ,
32- Self :: Other ( e) => write ! ( f, "Other error: {}" , e) ,
33- }
34- }
35- }
21+ #[ error( "Internal server error: {0}" ) ]
22+ InternalError ( String ) ,
3623
37- impl Error for GatewayError {
38- fn source ( & self ) -> Option < & ( dyn Error + ' static ) > {
39- match self {
40- Self :: RequestError ( e) => Some ( e) ,
41- Self :: Other ( e) => Some ( e. as_ref ( ) ) ,
42- _ => None ,
43- }
44- }
45- }
24+ #[ error( "Request error: {0}" ) ]
25+ RequestError ( #[ from] reqwest:: Error ) ,
4626
47- impl From < reqwest:: Error > for GatewayError {
48- fn from ( err : reqwest:: Error ) -> Self {
49- Self :: RequestError ( err)
50- }
27+ #[ error( "Other error: {0}" ) ]
28+ Other ( #[ from] Box < dyn std:: error:: Error + Send + Sync > ) ,
5129}
5230
5331#[ derive( Debug , Deserialize ) ]
@@ -181,7 +159,7 @@ pub trait InferenceGatewayAPI {
181159 ) -> Result < GenerateResponse , GatewayError > ;
182160
183161 /// Checks if the API is available
184- fn health_check ( & self ) -> Result < bool , Box < dyn Error > > ;
162+ fn health_check ( & self ) -> Result < bool , GatewayError > ;
185163}
186164
187165impl InferenceGatewayClient {
@@ -308,7 +286,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
308286 }
309287 }
310288
311- fn health_check ( & self ) -> Result < bool , Box < dyn Error > > {
289+ fn health_check ( & self ) -> Result < bool , GatewayError > {
312290 let url = format ! ( "{}/health" , self . base_url) ;
313291 let response = self . client . get ( & url) . send ( ) ?;
314292 Ok ( response. status ( ) . is_success ( ) )
@@ -320,6 +298,59 @@ mod tests {
320298 use super :: * ;
321299 use mockito:: { Matcher , Server } ;
322300
301+ #[ test]
302+ fn test_provider_serialization ( ) {
303+ let providers = vec ! [
304+ ( Provider :: Ollama , "ollama" ) ,
305+ ( Provider :: Groq , "groq" ) ,
306+ ( Provider :: OpenAI , "openai" ) ,
307+ ( Provider :: Google , "google" ) ,
308+ ( Provider :: Cloudflare , "cloudflare" ) ,
309+ ( Provider :: Cohere , "cohere" ) ,
310+ ( Provider :: Anthropic , "anthropic" ) ,
311+ ] ;
312+
313+ for ( provider, expected) in providers {
314+ let json = serde_json:: to_string ( & provider) . unwrap ( ) ;
315+ assert_eq ! ( json, format!( "\" {}\" " , expected) ) ;
316+ }
317+ }
318+
319+ #[ test]
320+ fn test_provider_deserialization ( ) {
321+ let test_cases = vec ! [
322+ ( "\" ollama\" " , Provider :: Ollama ) ,
323+ ( "\" groq\" " , Provider :: Groq ) ,
324+ ( "\" openai\" " , Provider :: OpenAI ) ,
325+ ( "\" google\" " , Provider :: Google ) ,
326+ ( "\" cloudflare\" " , Provider :: Cloudflare ) ,
327+ ( "\" cohere\" " , Provider :: Cohere ) ,
328+ ( "\" anthropic\" " , Provider :: Anthropic ) ,
329+ ] ;
330+
331+ for ( json, expected) in test_cases {
332+ let provider: Provider = serde_json:: from_str ( json) . unwrap ( ) ;
333+ assert_eq ! ( provider, expected) ;
334+ }
335+ }
336+
337+ #[ test]
338+ fn test_provider_display ( ) {
339+ let providers = vec ! [
340+ ( Provider :: Ollama , "ollama" ) ,
341+ ( Provider :: Groq , "groq" ) ,
342+ ( Provider :: OpenAI , "openai" ) ,
343+ ( Provider :: Google , "google" ) ,
344+ ( Provider :: Cloudflare , "cloudflare" ) ,
345+ ( Provider :: Cohere , "cohere" ) ,
346+ ( Provider :: Anthropic , "anthropic" ) ,
347+ ] ;
348+
349+ for ( provider, expected) in providers {
350+ assert_eq ! ( provider. to_string( ) , expected) ;
351+ }
352+ }
353+
323354 #[ test]
324355 fn test_authentication_header ( ) {
325356 let mut server = Server :: new ( ) ;
0 commit comments