@@ -6,7 +6,8 @@ use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
6
6
use serde:: { Deserialize , Serialize } ;
7
7
use serde_json:: Value ;
8
8
use std:: fmt:: Display ;
9
- use tower_lsp:: jsonrpc;
9
+
10
+ use crate :: error:: { Error , Result } ;
10
11
11
12
fn build_tgi_body ( prompt : String , params : & RequestParams ) -> Value {
12
13
serde_json:: json!( {
@@ -21,57 +22,43 @@ fn build_tgi_body(prompt: String, params: &RequestParams) -> Value {
21
22
} )
22
23
}
23
24
24
- fn build_tgi_headers ( api_token : Option < & String > , ide : Ide ) -> Result < HeaderMap , jsonrpc :: Error > {
25
+ fn build_tgi_headers ( api_token : Option < & String > , ide : Ide ) -> Result < HeaderMap > {
25
26
let mut headers = HeaderMap :: new ( ) ;
26
27
let user_agent = format ! ( "{NAME}/{VERSION}; rust/unknown; ide/{ide:?}" ) ;
27
- headers. insert (
28
- USER_AGENT ,
29
- HeaderValue :: from_str ( & user_agent) . map_err ( internal_error) ?,
30
- ) ;
28
+ headers. insert ( USER_AGENT , HeaderValue :: from_str ( & user_agent) ?) ;
31
29
32
30
if let Some ( api_token) = api_token {
33
31
headers. insert (
34
32
AUTHORIZATION ,
35
- HeaderValue :: from_str ( & format ! ( "Bearer {api_token}" ) ) . map_err ( internal_error ) ?,
33
+ HeaderValue :: from_str ( & format ! ( "Bearer {api_token}" ) ) ?,
36
34
) ;
37
35
}
38
36
39
37
Ok ( headers)
40
38
}
41
39
42
- fn parse_tgi_text ( text : & str ) -> Result < Vec < Generation > , jsonrpc:: Error > {
43
- let generations =
44
- match serde_json:: from_str ( text) . map_err ( internal_error) ? {
45
- APIResponse :: Generation ( gen) => vec ! [ gen ] ,
46
- APIResponse :: Generations ( _) => {
47
- return Err ( internal_error (
48
- "You are attempting to parse a result in the API inference format when using the `tgi` adaptor" ,
49
- ) )
50
- }
51
- APIResponse :: Error ( err) => return Err ( internal_error ( err) ) ,
52
- } ;
53
- Ok ( generations)
40
+ fn parse_tgi_text ( text : & str ) -> Result < Vec < Generation > > {
41
+ match serde_json:: from_str ( text) ? {
42
+ APIResponse :: Generation ( gen) => Ok ( vec ! [ gen ] ) ,
43
+ APIResponse :: Generations ( _) => Err ( Error :: InvalidAdaptor ) ,
44
+ APIResponse :: Error ( err) => Err ( Error :: Tgi ( err) ) ,
45
+ }
54
46
}
55
47
56
48
fn build_api_body ( prompt : String , params : & RequestParams ) -> Value {
57
49
build_tgi_body ( prompt, params)
58
50
}
59
51
60
- fn build_api_headers ( api_token : Option < & String > , ide : Ide ) -> Result < HeaderMap , jsonrpc :: Error > {
52
+ fn build_api_headers ( api_token : Option < & String > , ide : Ide ) -> Result < HeaderMap > {
61
53
build_tgi_headers ( api_token, ide)
62
54
}
63
55
64
- fn parse_api_text ( text : & str ) -> Result < Vec < Generation > , jsonrpc:: Error > {
65
- // TODO:
66
- // APIResponse::Generation(gen) => Ok(vec![gen]),
67
- // APIResponse::Generations(gens) => Ok(gens),
68
- // APIResponse::Error(err) => Err(err),
69
- let generations = match serde_json:: from_str ( text) . map_err ( internal_error) ? {
70
- APIResponse :: Generation ( gen) => vec ! [ gen ] ,
71
- APIResponse :: Generations ( gens) => gens,
72
- APIResponse :: Error ( err) => return Err ( internal_error ( err) ) ,
73
- } ;
74
- Ok ( generations)
56
+ fn parse_api_text ( text : & str ) -> Result < Vec < Generation > > {
57
+ match serde_json:: from_str ( text) ? {
58
+ APIResponse :: Generation ( gen) => Ok ( vec ! [ gen ] ) ,
59
+ APIResponse :: Generations ( gens) => Ok ( gens) ,
60
+ APIResponse :: Error ( err) => Err ( Error :: InferenceApi ( err) ) ,
61
+ }
75
62
}
76
63
77
64
fn build_ollama_body ( prompt : String , params : & CompletionParams ) -> Value {
@@ -88,7 +75,7 @@ fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
88
75
}
89
76
} )
90
77
}
91
- fn build_ollama_headers ( ) -> Result < HeaderMap , jsonrpc :: Error > {
78
+ fn build_ollama_headers ( ) -> Result < HeaderMap > {
92
79
Ok ( HeaderMap :: new ( ) )
93
80
}
94
81
@@ -112,12 +99,11 @@ enum OllamaAPIResponse {
112
99
Error ( APIError ) ,
113
100
}
114
101
115
- fn parse_ollama_text ( text : & str ) -> Result < Vec < Generation > , jsonrpc:: Error > {
116
- let generations = match serde_json:: from_str ( text) . map_err ( internal_error) ? {
117
- OllamaAPIResponse :: Generation ( gen) => vec ! [ gen . into( ) ] ,
118
- OllamaAPIResponse :: Error ( err) => return Err ( internal_error ( err) ) ,
119
- } ;
120
- Ok ( generations)
102
+ fn parse_ollama_text ( text : & str ) -> Result < Vec < Generation > > {
103
+ match serde_json:: from_str ( text) ? {
104
+ OllamaAPIResponse :: Generation ( gen) => Ok ( vec ! [ gen . into( ) ] ) ,
105
+ OllamaAPIResponse :: Error ( err) => Err ( Error :: Ollama ( err) ) ,
106
+ }
121
107
}
122
108
123
109
fn build_openai_body ( prompt : String , params : & CompletionParams ) -> Value {
@@ -131,7 +117,7 @@ fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
131
117
} )
132
118
}
133
119
134
- fn build_openai_headers ( api_token : Option < & String > , ide : Ide ) -> Result < HeaderMap , jsonrpc :: Error > {
120
+ fn build_openai_headers ( api_token : Option < & String > , ide : Ide ) -> Result < HeaderMap > {
135
121
build_api_headers ( api_token, ide)
136
122
}
137
123
@@ -177,7 +163,7 @@ struct OpenAIErrorDetail {
177
163
}
178
164
179
165
#[ derive( Debug , Deserialize ) ]
180
- struct OpenAIError {
166
+ pub struct OpenAIError {
181
167
detail : Vec < OpenAIErrorDetail > ,
182
168
}
183
169
@@ -200,13 +186,13 @@ enum OpenAIAPIResponse {
200
186
Error ( OpenAIError ) ,
201
187
}
202
188
203
- fn parse_openai_text ( text : & str ) -> Result < Vec < Generation > , jsonrpc:: Error > {
204
- match serde_json:: from_str ( text) . map_err ( internal_error) {
205
- Ok ( OpenAIAPIResponse :: Generation ( completion) ) => {
189
+ fn parse_openai_text ( text : & str ) -> Result < Vec < Generation > > {
190
+ let open_ai_response = serde_json:: from_str ( text) ?;
191
+ match open_ai_response {
192
+ OpenAIAPIResponse :: Generation ( completion) => {
206
193
Ok ( completion. choices . into_iter ( ) . map ( |x| x. into ( ) ) . collect ( ) )
207
194
}
208
- Ok ( OpenAIAPIResponse :: Error ( err) ) => Err ( internal_error ( err) ) ,
209
- Err ( err) => Err ( internal_error ( err) ) ,
195
+ OpenAIAPIResponse :: Error ( err) => Err ( Error :: OpenAI ( err) ) ,
210
196
}
211
197
}
212
198
@@ -216,11 +202,7 @@ const OLLAMA: &str = "ollama";
216
202
const OPENAI : & str = "openai" ;
217
203
const DEFAULT_ADAPTOR : & str = HUGGING_FACE ;
218
204
219
- fn unknown_adaptor_error ( adaptor : Option < & String > ) -> jsonrpc:: Error {
220
- internal_error ( format ! ( "Unknown adaptor {:?}" , adaptor) )
221
- }
222
-
223
- pub fn adapt_body ( prompt : String , params : & CompletionParams ) -> Result < Value , jsonrpc:: Error > {
205
+ pub fn adapt_body ( prompt : String , params : & CompletionParams ) -> Result < Value > {
224
206
match params
225
207
. adaptor
226
208
. as_ref ( )
@@ -231,30 +213,30 @@ pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result<Value, js
231
213
HUGGING_FACE => Ok ( build_api_body ( prompt, & params. request_params ) ) ,
232
214
OLLAMA => Ok ( build_ollama_body ( prompt, params) ) ,
233
215
OPENAI => Ok ( build_openai_body ( prompt, params) ) ,
234
- _ => Err ( unknown_adaptor_error ( params . adaptor . as_ref ( ) ) ) ,
216
+ adaptor => Err ( Error :: UnknownAdaptor ( adaptor. to_owned ( ) ) ) ,
235
217
}
236
218
}
237
219
238
220
pub fn adapt_headers (
239
221
adaptor : Option < & String > ,
240
222
api_token : Option < & String > ,
241
223
ide : Ide ,
242
- ) -> Result < HeaderMap , jsonrpc :: Error > {
224
+ ) -> Result < HeaderMap > {
243
225
match adaptor. unwrap_or ( & DEFAULT_ADAPTOR . to_string ( ) ) . as_str ( ) {
244
226
TGI => build_tgi_headers ( api_token, ide) ,
245
227
HUGGING_FACE => build_api_headers ( api_token, ide) ,
246
228
OLLAMA => build_ollama_headers ( ) ,
247
229
OPENAI => build_openai_headers ( api_token, ide) ,
248
- _ => Err ( unknown_adaptor_error ( adaptor) ) ,
230
+ adaptor => Err ( Error :: UnknownAdaptor ( adaptor. to_owned ( ) ) ) ,
249
231
}
250
232
}
251
233
252
- pub fn parse_generations ( adaptor : Option < & String > , text : & str ) -> jsonrpc :: Result < Vec < Generation > > {
234
+ pub fn parse_generations ( adaptor : Option < & String > , text : & str ) -> Result < Vec < Generation > > {
253
235
match adaptor. unwrap_or ( & DEFAULT_ADAPTOR . to_string ( ) ) . as_str ( ) {
254
236
TGI => parse_tgi_text ( text) ,
255
237
HUGGING_FACE => parse_api_text ( text) ,
256
238
OLLAMA => parse_ollama_text ( text) ,
257
239
OPENAI => parse_openai_text ( text) ,
258
- _ => Err ( unknown_adaptor_error ( adaptor) ) ,
240
+ adaptor => Err ( Error :: UnknownAdaptor ( adaptor. to_owned ( ) ) ) ,
259
241
}
260
242
}
0 commit comments