Skip to content

Commit 4e0e03a

Browse files
committed
refactor: error handling
1 parent 0c074b7 commit 4e0e03a

File tree

3 files changed

+52
-63
lines changed

3 files changed

+52
-63
lines changed

crates/llm-ls/src/adaptors.rs

Lines changed: 37 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, USER_AGENT};
66
use serde::{Deserialize, Serialize};
77
use serde_json::Value;
88
use std::fmt::Display;
9-
use tower_lsp::jsonrpc;
9+
10+
use crate::error::{Error, Result};
1011

1112
fn build_tgi_body(prompt: String, params: &RequestParams) -> Value {
1213
serde_json::json!({
@@ -21,57 +22,43 @@ fn build_tgi_body(prompt: String, params: &RequestParams) -> Value {
2122
})
2223
}
2324

24-
fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
25+
fn build_tgi_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
2526
let mut headers = HeaderMap::new();
2627
let user_agent = format!("{NAME}/{VERSION}; rust/unknown; ide/{ide:?}");
27-
headers.insert(
28-
USER_AGENT,
29-
HeaderValue::from_str(&user_agent).map_err(internal_error)?,
30-
);
28+
headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
3129

3230
if let Some(api_token) = api_token {
3331
headers.insert(
3432
AUTHORIZATION,
35-
HeaderValue::from_str(&format!("Bearer {api_token}")).map_err(internal_error)?,
33+
HeaderValue::from_str(&format!("Bearer {api_token}"))?,
3634
);
3735
}
3836

3937
Ok(headers)
4038
}
4139

42-
fn parse_tgi_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
43-
let generations =
44-
match serde_json::from_str(text).map_err(internal_error)? {
45-
APIResponse::Generation(gen) => vec![gen],
46-
APIResponse::Generations(_) => {
47-
return Err(internal_error(
48-
"You are attempting to parse a result in the API inference format when using the `tgi` adaptor",
49-
))
50-
}
51-
APIResponse::Error(err) => return Err(internal_error(err)),
52-
};
53-
Ok(generations)
40+
fn parse_tgi_text(text: &str) -> Result<Vec<Generation>> {
41+
match serde_json::from_str(text)? {
42+
APIResponse::Generation(gen) => Ok(vec![gen]),
43+
APIResponse::Generations(_) => Err(Error::InvalidAdaptor),
44+
APIResponse::Error(err) => Err(Error::Tgi(err)),
45+
}
5446
}
5547

5648
fn build_api_body(prompt: String, params: &RequestParams) -> Value {
5749
build_tgi_body(prompt, params)
5850
}
5951

60-
fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
52+
fn build_api_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
6153
build_tgi_headers(api_token, ide)
6254
}
6355

64-
fn parse_api_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
65-
// TODO:
66-
// APIResponse::Generation(gen) => Ok(vec![gen]),
67-
// APIResponse::Generations(gens) => Ok(gens),
68-
// APIResponse::Error(err) => Err(err),
69-
let generations = match serde_json::from_str(text).map_err(internal_error)? {
70-
APIResponse::Generation(gen) => vec![gen],
71-
APIResponse::Generations(gens) => gens,
72-
APIResponse::Error(err) => return Err(internal_error(err)),
73-
};
74-
Ok(generations)
56+
fn parse_api_text(text: &str) -> Result<Vec<Generation>> {
57+
match serde_json::from_str(text)? {
58+
APIResponse::Generation(gen) => Ok(vec![gen]),
59+
APIResponse::Generations(gens) => Ok(gens),
60+
APIResponse::Error(err) => Err(Error::InferenceApi(err)),
61+
}
7562
}
7663

7764
fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
@@ -88,7 +75,7 @@ fn build_ollama_body(prompt: String, params: &CompletionParams) -> Value {
8875
}
8976
})
9077
}
91-
fn build_ollama_headers() -> Result<HeaderMap, jsonrpc::Error> {
78+
fn build_ollama_headers() -> Result<HeaderMap> {
9279
Ok(HeaderMap::new())
9380
}
9481

@@ -112,12 +99,11 @@ enum OllamaAPIResponse {
11299
Error(APIError),
113100
}
114101

115-
fn parse_ollama_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
116-
let generations = match serde_json::from_str(text).map_err(internal_error)? {
117-
OllamaAPIResponse::Generation(gen) => vec![gen.into()],
118-
OllamaAPIResponse::Error(err) => return Err(internal_error(err)),
119-
};
120-
Ok(generations)
102+
fn parse_ollama_text(text: &str) -> Result<Vec<Generation>> {
103+
match serde_json::from_str(text)? {
104+
OllamaAPIResponse::Generation(gen) => Ok(vec![gen.into()]),
105+
OllamaAPIResponse::Error(err) => Err(Error::Ollama(err)),
106+
}
121107
}
122108

123109
fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
@@ -131,7 +117,7 @@ fn build_openai_body(prompt: String, params: &CompletionParams) -> Value {
131117
})
132118
}
133119

134-
fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap, jsonrpc::Error> {
120+
fn build_openai_headers(api_token: Option<&String>, ide: Ide) -> Result<HeaderMap> {
135121
build_api_headers(api_token, ide)
136122
}
137123

@@ -177,7 +163,7 @@ struct OpenAIErrorDetail {
177163
}
178164

179165
#[derive(Debug, Deserialize)]
180-
struct OpenAIError {
166+
pub struct OpenAIError {
181167
detail: Vec<OpenAIErrorDetail>,
182168
}
183169

@@ -200,13 +186,13 @@ enum OpenAIAPIResponse {
200186
Error(OpenAIError),
201187
}
202188

203-
fn parse_openai_text(text: &str) -> Result<Vec<Generation>, jsonrpc::Error> {
204-
match serde_json::from_str(text).map_err(internal_error) {
205-
Ok(OpenAIAPIResponse::Generation(completion)) => {
189+
fn parse_openai_text(text: &str) -> Result<Vec<Generation>> {
190+
let open_ai_response = serde_json::from_str(text)?;
191+
match open_ai_response {
192+
OpenAIAPIResponse::Generation(completion) => {
206193
Ok(completion.choices.into_iter().map(|x| x.into()).collect())
207194
}
208-
Ok(OpenAIAPIResponse::Error(err)) => Err(internal_error(err)),
209-
Err(err) => Err(internal_error(err)),
195+
OpenAIAPIResponse::Error(err) => Err(Error::OpenAI(err)),
210196
}
211197
}
212198

@@ -216,11 +202,7 @@ const OLLAMA: &str = "ollama";
216202
const OPENAI: &str = "openai";
217203
const DEFAULT_ADAPTOR: &str = HUGGING_FACE;
218204

219-
fn unknown_adaptor_error(adaptor: Option<&String>) -> jsonrpc::Error {
220-
internal_error(format!("Unknown adaptor {:?}", adaptor))
221-
}
222-
223-
pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result<Value, jsonrpc::Error> {
205+
pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result<Value> {
224206
match params
225207
.adaptor
226208
.as_ref()
@@ -231,30 +213,30 @@ pub fn adapt_body(prompt: String, params: &CompletionParams) -> Result<Value, js
231213
HUGGING_FACE => Ok(build_api_body(prompt, &params.request_params)),
232214
OLLAMA => Ok(build_ollama_body(prompt, params)),
233215
OPENAI => Ok(build_openai_body(prompt, params)),
234-
_ => Err(unknown_adaptor_error(params.adaptor.as_ref())),
216+
adaptor => Err(Error::UnknownAdaptor(adaptor.to_owned())),
235217
}
236218
}
237219

238220
pub fn adapt_headers(
239221
adaptor: Option<&String>,
240222
api_token: Option<&String>,
241223
ide: Ide,
242-
) -> Result<HeaderMap, jsonrpc::Error> {
224+
) -> Result<HeaderMap> {
243225
match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
244226
TGI => build_tgi_headers(api_token, ide),
245227
HUGGING_FACE => build_api_headers(api_token, ide),
246228
OLLAMA => build_ollama_headers(),
247229
OPENAI => build_openai_headers(api_token, ide),
248-
_ => Err(unknown_adaptor_error(adaptor)),
230+
adaptor => Err(Error::UnknownAdaptor(adaptor.to_owned())),
249231
}
250232
}
251233

252-
pub fn parse_generations(adaptor: Option<&String>, text: &str) -> jsonrpc::Result<Vec<Generation>> {
234+
pub fn parse_generations(adaptor: Option<&String>, text: &str) -> Result<Vec<Generation>> {
253235
match adaptor.unwrap_or(&DEFAULT_ADAPTOR.to_string()).as_str() {
254236
TGI => parse_tgi_text(text),
255237
HUGGING_FACE => parse_api_text(text),
256238
OLLAMA => parse_ollama_text(text),
257239
OPENAI => parse_openai_text(text),
258-
_ => Err(unknown_adaptor_error(adaptor)),
240+
adaptor => Err(Error::UnknownAdaptor(adaptor.to_owned())),
259241
}
260242
}

crates/llm-ls/src/error.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ pub fn internal_error<E: Display>(err: E) -> LspError {
1818

1919
#[derive(thiserror::Error, Debug)]
2020
pub enum Error {
21-
#[error("backend api error: {0}")]
22-
Api(#[from] APIError),
2321
#[error("arrow error: {0}")]
2422
Arrow(#[from] arrow_schema::ArrowError),
2523
#[error("candle error: {0}")]
@@ -32,12 +30,20 @@ pub enum Error {
3230
Http(#[from] reqwest::Error),
3331
#[error("io error: {0}")]
3432
Io(#[from] std::io::Error),
33+
#[error("inference api error: {0}")]
34+
InferenceApi(APIError),
35+
#[error("You are attempting to parse a result in the API inference format when using the `tgi` adaptor")]
36+
InvalidAdaptor,
3537
#[error("invalid header value: {0}")]
3638
InvalidHeaderValue(#[from] reqwest::header::InvalidHeaderValue),
3739
#[error("invalid repository id")]
3840
InvalidRepositoryId,
3941
#[error("invalid tokenizer path")]
4042
InvalidTokenizerPath,
43+
#[error("ollama error: {0}")]
44+
Ollama(APIError),
45+
#[error("openai error: {0}")]
46+
OpenAI(crate::adaptors::OpenAIError),
4147
#[error("index out of bounds: {0}")]
4248
OutOfBoundIndexing(usize),
4349
#[error("line out of bounds: {0}")]
@@ -48,10 +54,14 @@ pub enum Error {
4854
Rope(#[from] ropey::Error),
4955
#[error("serde json error: {0}")]
5056
SerdeJson(#[from] serde_json::Error),
57+
#[error("tgi error: {0}")]
58+
Tgi(APIError),
5159
#[error("tokenizer error: {0}")]
5260
Tokenizer(#[from] tokenizers::Error),
5361
#[error("tokio join error: {0}")]
5462
TokioJoin(#[from] tokio::task::JoinError),
63+
#[error("unknown adaptor: {0}")]
64+
UnknownAdaptor(String),
5565
#[error("vector db error: {0}")]
5666
VectorDb(#[from] vectordb::error::Error),
5767
}

crates/llm-ls/src/main.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ async fn request_completion(
397397
) -> Result<Vec<Generation>> {
398398
let t = Instant::now();
399399

400-
let json = adapt_body(prompt, params).map_err(internal_error)?;
400+
let json = adapt_body(prompt, params)?;
401401
let headers = adapt_headers(
402402
params.adaptor.as_ref(),
403403
params.api_token.as_ref(),
@@ -411,18 +411,15 @@ async fn request_completion(
411411
.await?;
412412

413413
let model = &params.model;
414-
let generations = parse_generations(
415-
params.adaptor.as_ref(),
416-
res.text().await.map_err(internal_error)?.as_str(),
417-
);
414+
let generations = parse_generations(params.adaptor.as_ref(), res.text().await?.as_str())?;
418415
let time = t.elapsed().as_millis();
419416
info!(
420417
model,
421418
compute_generations_ms = time,
422419
generations = serde_json::to_string(&generations)?,
423420
"{model} computed generations in {time} ms"
424421
);
425-
generations
422+
Ok(generations)
426423
}
427424

428425
fn format_generations(

0 commit comments

Comments
 (0)