diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 65a75f550..ff0cb3cc7 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -429,6 +429,10 @@ where mas_router::OAuth2AuthorizationEndpoint::route(), get(self::oauth2::authorization::get), ) + .route( + mas_router::OAuth2EndSession::route(), + get(self::oauth2::end_session::get), + ) .route( mas_router::Consent::route(), get(self::oauth2::authorization::consent::get) diff --git a/crates/handlers/src/oauth2/discovery.rs b/crates/handlers/src/oauth2/discovery.rs index 61dfc1ba5..5275697aa 100644 --- a/crates/handlers/src/oauth2/discovery.rs +++ b/crates/handlers/src/oauth2/discovery.rs @@ -65,6 +65,7 @@ pub(crate) async fn get( let revocation_endpoint = Some(url_builder.oauth_revocation_endpoint()); let userinfo_endpoint = Some(url_builder.oidc_userinfo_endpoint()); let registration_endpoint = Some(url_builder.oauth_registration_endpoint()); + let end_session_endpoint = Some(url_builder.oauth_end_session_endpoint()); let scopes_supported = Some(vec![scope::OPENID.to_string(), scope::EMAIL.to_string()]); @@ -172,6 +173,7 @@ pub(crate) async fn get( request_uri_parameter_supported, prompt_values_supported, device_authorization_endpoint, + end_session_endpoint, ..ProviderMetadata::default() }; diff --git a/crates/handlers/src/oauth2/end_session.rs b/crates/handlers/src/oauth2/end_session.rs new file mode 100644 index 000000000..e1c6d698d --- /dev/null +++ b/crates/handlers/src/oauth2/end_session.rs @@ -0,0 +1,201 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial +// Please see LICENSE files in the repository root for full details. +use axum::{ + Json, + extract::State, + response::{IntoResponse, Redirect, Response}, +}; +use axum_extra::extract::Query; +use hyper::StatusCode; +use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, record_error}; +use mas_data_model::{BoxClock, BoxRng, Clock}; +use mas_keystore::Keystore; +use mas_oidc_client::{ + error::IdTokenError, + requests::jose::{JwtVerificationData, verify_id_token}, +}; +use mas_router::UrlBuilder; +use mas_storage::{ + BoxRepository, RepositoryAccess, + queue::{QueueJobRepositoryExt as _, SyncDevicesJob}, + user::BrowserSessionRepository, +}; +use oauth2_types::errors::{ClientError, ClientErrorCode}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use crate::{BoundActivityTracker, impl_from_error_for_route}; + +#[derive(Debug, Deserialize, Serialize)] +pub(crate) struct EndSessionParam { + id_token_hint: String, + post_logout_redirect_uri: String, +} + +#[derive(Debug, Error)] +pub(crate) enum RouteError { + #[error(transparent)] + Internal(Box), + + #[error("bad request")] + BadRequest, + + #[error("client not found")] + ClientNotFound, + + #[error("client is unauthorized")] + UnauthorizedClient, + + #[error("unknown token")] + UnknownToken, +} + +impl_from_error_for_route!(mas_storage::RepositoryError); + +impl IntoResponse for RouteError { + fn into_response(self) -> Response { + let sentry_event_id = record_error!(self, Self::Internal(_)); + let response = match self { + Self::Internal(_) => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ClientError::from(ClientErrorCode::ServerError)), + ) + .into_response(), + + Self::BadRequest => ( + StatusCode::BAD_REQUEST, + Json(ClientError::from(ClientErrorCode::InvalidRequest)), + ) + .into_response(), + + Self::ClientNotFound => ( + StatusCode::UNAUTHORIZED, + Json(ClientError::from(ClientErrorCode::InvalidClient)), + ) + .into_response(), + + Self::UnauthorizedClient => ( + StatusCode::UNAUTHORIZED, + Json(ClientError::from(ClientErrorCode::UnauthorizedClient)), + ) + .into_response(), + + // If the token is unknown, we still return a 200 OK response. + Self::UnknownToken => StatusCode::OK.into_response(), + }; + + (sentry_event_id, response).into_response() + } +} + +impl From for RouteError { + fn from(_e: IdTokenError) -> Self { + Self::UnknownToken + } +} + +#[tracing::instrument(name = "handlers.oauth2.end_session.get", skip_all)] +pub(crate) async fn get( + mut rng: BoxRng, + clock: BoxClock, + State(key_store): State, + State(url_builder): State, + mut repo: BoxRepository, + activity_tracker: BoundActivityTracker, + Query(params): Query, + cookie_jar: CookieJar, +) -> Result { + let (session_info, cookie_jar) = cookie_jar.session_info(); + + let browser_session_id = session_info + .current_session_id() + .ok_or(RouteError::BadRequest)?; + + let browser_session = repo + .browser_session() + .lookup(browser_session_id) + .await? + .ok_or(RouteError::BadRequest)?; + + let oauth_session = repo + .oauth2_session() + .find_by_browser_session(browser_session.id) + .await? + .ok_or(RouteError::BadRequest)?; + + let client = repo + .oauth2_client() + .lookup(oauth_session.client_id) + .await? + .filter(|client| client.id_token_signed_response_alg.is_some()) + .ok_or(RouteError::ClientNotFound)?; + + let jwks = key_store.public_jwks(); + let issuer: String = url_builder.oidc_issuer().into(); + + let id_token_verification_data = JwtVerificationData { + issuer: Some(&issuer), + jwks: &jwks, + signing_algorithm: &client.id_token_signed_response_alg.unwrap(), + client_id: &client.client_id, + }; + + verify_id_token( + ¶ms.id_token_hint, + id_token_verification_data, + None, + clock.now(), + )?; + + // Check that the session is still valid. + if !oauth_session.is_valid() { + // If the session is not valid, we redirect to post logout uri + return Ok((cookie_jar, Redirect::to(¶ms.post_logout_redirect_uri)).into_response()); + } + + // Check that the client ending the session is the same as the client that + // created it. + if client.id != oauth_session.client_id { + return Err(RouteError::UnauthorizedClient); + } + + activity_tracker + .record_oauth2_session(&clock, &oauth_session) + .await; + + // If the session is associated with a user, make sure we schedule a device + // deletion job for all the devices associated with the session. + if let Some(user_id) = oauth_session.user_id { + // Fetch the user + let user = repo + .user() + .lookup(user_id) + .await? + .ok_or(RouteError::UnknownToken)?; + + // Schedule a job to sync the devices of the user with the homeserver + repo.queue_job() + .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user)) + .await?; + } + + // Now that we checked everything, we can end the session. + repo.oauth2_session().finish(&clock, oauth_session).await?; + + activity_tracker + .record_browser_session(&clock, &browser_session) + .await; + repo.browser_session() + .finish(&clock, browser_session) + .await?; + + repo.save().await?; + + // We always want to clear out the session cookie, even if the session was + // invalid + let cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended()); + + Ok((cookie_jar, Redirect::to(¶ms.post_logout_redirect_uri)).into_response()) +} diff --git a/crates/handlers/src/oauth2/mod.rs b/crates/handlers/src/oauth2/mod.rs index cf28818e2..b0518bc2f 100644 --- a/crates/handlers/src/oauth2/mod.rs +++ b/crates/handlers/src/oauth2/mod.rs @@ -25,6 +25,7 @@ use thiserror::Error; pub mod authorization; pub mod device; pub mod discovery; +pub mod end_session; pub mod introspection; pub mod keys; pub mod registration; diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index 3440f8bc6..703b8e11c 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -155,6 +155,14 @@ impl SimpleRoute for OAuth2AuthorizationEndpoint { const PATH: &'static str = "/authorize"; } +/// `POST /oauth2/end_session` +#[derive(Default, Debug, Clone)] +pub struct OAuth2EndSession; + +impl SimpleRoute for OAuth2EndSession { + const PATH: &'static str = "/oauth2/end-session"; +} + /// `GET /` #[derive(Default, Debug, Clone)] pub struct Index; diff --git a/crates/router/src/url_builder.rs b/crates/router/src/url_builder.rs index f216fb343..14a88924a 100644 --- a/crates/router/src/url_builder.rs +++ b/crates/router/src/url_builder.rs @@ -160,6 +160,12 @@ impl UrlBuilder { self.absolute_url_for(&crate::endpoints::OAuth2Revocation) } + /// OAuth 2.0 revocation endpoint + #[must_use] + pub fn oauth_end_session_endpoint(&self) -> Url { + self.absolute_url_for(&crate::endpoints::OAuth2EndSession) + } + /// OAuth 2.0 client registration endpoint #[must_use] pub fn oauth_registration_endpoint(&self) -> Url { diff --git a/crates/storage-pg/.sqlx/query-39f88cc7c5c4d2e206aa5f2e91c9c2c0abdcf4672438c54b7152733e66bb85cf.json b/crates/storage-pg/.sqlx/query-39f88cc7c5c4d2e206aa5f2e91c9c2c0abdcf4672438c54b7152733e66bb85cf.json new file mode 100644 index 000000000..fd3276baf --- /dev/null +++ b/crates/storage-pg/.sqlx/query-39f88cc7c5c4d2e206aa5f2e91c9c2c0abdcf4672438c54b7152733e66bb85cf.json @@ -0,0 +1,82 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT oauth2_session_id\n , user_id\n , user_session_id\n , oauth2_client_id\n , scope_list\n , created_at\n , finished_at\n , user_agent\n , last_active_at\n , last_active_ip as \"last_active_ip: IpAddr\"\n , human_name\n FROM oauth2_sessions\n\n WHERE user_session_id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "oauth2_session_id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "user_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, + "name": "user_session_id", + "type_info": "Uuid" + }, + { + "ordinal": 3, + "name": "oauth2_client_id", + "type_info": "Uuid" + }, + { + "ordinal": 4, + "name": "scope_list", + "type_info": "TextArray" + }, + { + "ordinal": 5, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 6, + "name": "finished_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 7, + "name": "user_agent", + "type_info": "Text" + }, + { + "ordinal": 8, + "name": "last_active_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 9, + "name": "last_active_ip: IpAddr", + "type_info": "Inet" + }, + { + "ordinal": 10, + "name": "human_name", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + true, + true, + false, + false, + false, + true, + true, + true, + true, + true + ] + }, + "hash": "39f88cc7c5c4d2e206aa5f2e91c9c2c0abdcf4672438c54b7152733e66bb85cf" +} diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index 072691a06..9550757c6 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -592,4 +592,43 @@ impl OAuth2SessionRepository for PgOAuth2SessionRepository<'_> { Ok(session) } + + #[tracing::instrument( + name = "db.oauth2_session.find_by_browser_session", + skip_all, + fields( + db.query.text, + session.id = %id, + ), + err, + )] + async fn find_by_browser_session(&mut self, id: Ulid) -> Result, Self::Error> { + let res = sqlx::query_as!( + OAuthSessionLookup, + r#" + SELECT oauth2_session_id + , user_id + , user_session_id + , oauth2_client_id + , scope_list + , created_at + , finished_at + , user_agent + , last_active_at + , last_active_ip as "last_active_ip: IpAddr" + , human_name + FROM oauth2_sessions + + WHERE user_session_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .fetch_optional(&mut *self.conn) + .await?; + + let Some(session) = res else { return Ok(None) }; + + Ok(Some(session.try_into()?)) + } } diff --git a/crates/storage/src/oauth2/session.rs b/crates/storage/src/oauth2/session.rs index 30ac1abe1..b2b9b23a3 100644 --- a/crates/storage/src/oauth2/session.rs +++ b/crates/storage/src/oauth2/session.rs @@ -461,6 +461,19 @@ pub trait OAuth2SessionRepository: Send + Sync { session: Session, human_name: Option, ) -> Result; + + /// Lookup an [`Session`] by its browser session id + /// + /// Returns `None` if no [`Session`] was found + /// + /// # Parameters + /// + /// * `id`: The ID of the [`Session`] to lookup + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn find_by_browser_session(&mut self, id: Ulid) -> Result, Self::Error>; } repository_impl!(OAuth2SessionRepository: @@ -526,4 +539,6 @@ repository_impl!(OAuth2SessionRepository: session: Session, human_name: Option, ) -> Result; + + async fn find_by_browser_session(&mut self, id: Ulid) -> Result, Self::Error>; );