Skip to content

Commit 5216ab2

Browse files
authored
fix: return auth errors (#451)
1 parent a11c4de commit 5216ab2

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
use std::sync::Arc;
1+
use std::{borrow::Cow, sync::Arc};
22

33
use futures::{StreamExt, stream::BoxStream};
4+
use http::header::WWW_AUTHENTICATE;
45
use reqwest::header::ACCEPT;
56
use sse_stream::{Sse, SseStream};
67

@@ -101,7 +102,23 @@ impl StreamableHttpClient for reqwest::Client {
101102
if let Some(session_id) = session_id {
102103
request = request.header(HEADER_SESSION_ID, session_id.as_ref());
103104
}
104-
let response = request.json(&message).send().await?.error_for_status()?;
105+
let response = request.json(&message).send().await?;
106+
if response.status() == reqwest::StatusCode::UNAUTHORIZED {
107+
if let Some(header) = response.headers().get(WWW_AUTHENTICATE) {
108+
let header = header
109+
.to_str()
110+
.map_err(|_| {
111+
StreamableHttpError::UnexpectedServerResponse(Cow::from(
112+
"invalid www-authenticate header value",
113+
))
114+
})?
115+
.to_string();
116+
return Err(StreamableHttpError::AuthRequired(AuthRequiredError {
117+
www_authenticate_header: header,
118+
}));
119+
}
120+
}
121+
let response = response.error_for_status()?;
105122
if response.status() == reqwest::StatusCode::ACCEPTED {
106123
return Ok(StreamableHttpPostResponse::Accepted);
107124
}

crates/rmcp/src/transport/streamable_http_client.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ use crate::{
1818

1919
type BoxedSseStream = BoxStream<'static, Result<Sse, SseError>>;
2020

21+
#[derive(Debug)]
22+
pub struct AuthRequiredError {
23+
pub www_authenticate_header: String,
24+
}
25+
2126
#[derive(Error, Debug)]
2227
pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
2328
#[error("SSE error: {0}")]
@@ -48,6 +53,8 @@ pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
4853
#[cfg_attr(docsrs, doc(cfg(feature = "auth")))]
4954
#[error("Auth error: {0}")]
5055
Auth(#[from] crate::transport::auth::AuthError),
56+
#[error("Auth required")]
57+
AuthRequired(AuthRequiredError),
5158
}
5259

5360
#[derive(Debug, Clone, Error)]
@@ -274,8 +281,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
274281
responder,
275282
message: initialize_request,
276283
} = context.recv_from_handler().await?;
277-
let _ = responder.send(Ok(()));
278-
let (message, session_id) = self
284+
let (message, session_id) = match self
279285
.client
280286
.post_message(
281287
config.uri.clone(),
@@ -284,12 +290,22 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
284290
self.config.auth_header,
285291
)
286292
.await
287-
.map_err(WorkerQuitReason::fatal_context("send initialize request"))?
288-
.expect_initialized::<C::Error>()
289-
.await
290-
.map_err(WorkerQuitReason::fatal_context(
291-
"process initialize response",
292-
))?;
293+
{
294+
Ok(res) => {
295+
let _ = responder.send(Ok(()));
296+
res.expect_initialized::<C::Error>().await.map_err(
297+
WorkerQuitReason::fatal_context("process initialize response"),
298+
)?
299+
}
300+
Err(err) => {
301+
let msg = format!("{:?}", err);
302+
let _ = responder.send(Err(err));
303+
return Err(WorkerQuitReason::fatal(
304+
StreamableHttpError::TransportChannelClosed,
305+
msg,
306+
));
307+
}
308+
};
293309
let session_id: Option<Arc<str>> = if let Some(session_id) = session_id {
294310
Some(session_id.into())
295311
} else {

0 commit comments

Comments
 (0)