diff --git a/tests/test_session.py b/tests/test_session.py index 14dcdd24..d995e0f7 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,5 +1,6 @@ import concurrent.futures from datetime import datetime, timezone +import time from unittest.mock import AsyncMock, Mock, patch import jwt @@ -477,6 +478,174 @@ def test_refresh_success_with_aud_claim( assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + @with_jwks_mock + def test_authenticate_with_slightly_expired_jwt_fails_without_leeway( + self, session_constants, mock_user_management + ): + # Create a token that's expired by 5 seconds + current_time = int(time.time()) + + # Create token claims with exp 5 seconds in the past + token_claims = { + **session_constants["TEST_TOKEN_CLAIMS"], + "exp": current_time - 5, # Expired by 5 seconds + "iat": current_time - 60, # Issued 60 seconds ago + } + + slightly_expired_token = jwt.encode( + token_claims, + session_constants["PRIVATE_KEY"], + algorithm="RS256", + ) + + # Prepare sealed session data with the slightly expired token + session_data = Session.seal_data( + { + "access_token": slightly_expired_token, + "user": session_constants["TEST_USER"], + }, + session_constants["COOKIE_PASSWORD"], + ) + + # With default leeway=0, authentication should fail + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_data, + cookie_password=session_constants["COOKIE_PASSWORD"], + jwt_leeway=0, + ) + + response = session.authenticate() + assert response.authenticated is False + assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + + @with_jwks_mock + def test_authenticate_with_slightly_expired_jwt_succeeds_with_leeway( + self, session_constants, mock_user_management + ): + # Create a token that's expired by 5 seconds + current_time = int(time.time()) + + # Create token claims with exp 5 seconds in the past + token_claims = { + **session_constants["TEST_TOKEN_CLAIMS"], + "exp": current_time - 5, # Expired by 5 seconds + "iat": current_time - 60, # Issued 60 seconds ago + } + + slightly_expired_token = jwt.encode( + token_claims, + session_constants["PRIVATE_KEY"], + algorithm="RS256", + ) + + # Prepare sealed session data with the slightly expired token + session_data = Session.seal_data( + { + "access_token": slightly_expired_token, + "user": session_constants["TEST_USER"], + }, + session_constants["COOKIE_PASSWORD"], + ) + + # With leeway=10, authentication should succeed + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_data, + cookie_password=session_constants["COOKIE_PASSWORD"], + jwt_leeway=10, # 10 seconds leeway + ) + + response = session.authenticate() + assert response.authenticated is True + assert response.session_id == session_constants["TEST_TOKEN_CLAIMS"]["sid"] + + @with_jwks_mock + def test_authenticate_with_significantly_expired_jwt_fails_without_leeway( + self, session_constants, mock_user_management + ): + # Create a token that's expired by 60 seconds + current_time = int(time.time()) + + # Create token claims with exp 60 seconds in the past + token_claims = { + **session_constants["TEST_TOKEN_CLAIMS"], + "exp": current_time - 60, # Expired by 60 seconds + "iat": current_time - 120, # Issued 120 seconds ago + } + + significantly_expired_token = jwt.encode( + token_claims, + session_constants["PRIVATE_KEY"], + algorithm="RS256", + ) + + # Prepare sealed session data with the significantly expired token + session_data = Session.seal_data( + { + "access_token": significantly_expired_token, + "user": session_constants["TEST_USER"], + }, + session_constants["COOKIE_PASSWORD"], + ) + + # With default leeway=0, authentication should fail + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_data, + cookie_password=session_constants["COOKIE_PASSWORD"], + jwt_leeway=0, + ) + + response = session.authenticate() + assert response.authenticated is False + assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + + @with_jwks_mock + def test_authenticate_with_significantly_expired_jwt_fails_with_insufficient_leeway( + self, session_constants, mock_user_management + ): + # Create a token that's expired by 60 seconds + current_time = int(time.time()) + + # Create token claims with exp 60 seconds in the past + token_claims = { + **session_constants["TEST_TOKEN_CLAIMS"], + "exp": current_time - 60, # Expired by 60 seconds + "iat": current_time - 120, # Issued 120 seconds ago + } + + significantly_expired_token = jwt.encode( + token_claims, + session_constants["PRIVATE_KEY"], + algorithm="RS256", + ) + + # Prepare sealed session data with the significantly expired token + session_data = Session.seal_data( + { + "access_token": significantly_expired_token, + "user": session_constants["TEST_USER"], + }, + session_constants["COOKIE_PASSWORD"], + ) + + # With leeway=10, authentication should still fail (not enough leeway) + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_data, + cookie_password=session_constants["COOKIE_PASSWORD"], + jwt_leeway=10, # 10 seconds leeway is not enough for 60 seconds expiration + ) + + response = session.authenticate() + assert response.authenticated is False + assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT + class TestAsyncSession(SessionFixtures): @pytest.mark.asyncio diff --git a/workos/_base_client.py b/workos/_base_client.py index d805a80a..18f29d26 100644 --- a/workos/_base_client.py +++ b/workos/_base_client.py @@ -26,6 +26,7 @@ class BaseClient(ClientConfiguration): _base_url: str _client_id: str _request_timeout: int + _jwt_leeway: float def __init__( self, @@ -34,6 +35,7 @@ def __init__( client_id: Optional[str], base_url: Optional[str] = None, request_timeout: Optional[int] = None, + jwt_leeway: float = 0, ) -> None: api_key = api_key or os.getenv("WORKOS_API_KEY") if api_key is None: @@ -65,6 +67,8 @@ def __init__( else int(os.getenv("WORKOS_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)) ) + self._jwt_leeway = jwt_leeway + @property @abstractmethod def audit_logs(self) -> AuditLogsModule: ... @@ -127,3 +131,7 @@ def client_id(self) -> str: @property def request_timeout(self) -> int: return self._request_timeout + + @property + def jwt_leeway(self) -> float: + return self._jwt_leeway diff --git a/workos/async_client.py b/workos/async_client.py index 920c08ab..0a3281a8 100644 --- a/workos/async_client.py +++ b/workos/async_client.py @@ -30,12 +30,14 @@ def __init__( client_id: Optional[str] = None, base_url: Optional[str] = None, request_timeout: Optional[int] = None, + jwt_leeway: float = 0, ): super().__init__( api_key=api_key, client_id=client_id, base_url=base_url, request_timeout=request_timeout, + jwt_leeway=jwt_leeway, ) self._http_client = AsyncHTTPClient( api_key=self._api_key, diff --git a/workos/client.py b/workos/client.py index 9c4aa154..d6ccd674 100644 --- a/workos/client.py +++ b/workos/client.py @@ -30,12 +30,14 @@ def __init__( client_id: Optional[str] = None, base_url: Optional[str] = None, request_timeout: Optional[int] = None, + jwt_leeway: float = 0, ): super().__init__( api_key=api_key, client_id=client_id, base_url=base_url, request_timeout=request_timeout, + jwt_leeway=jwt_leeway, ) self._http_client = SyncHTTPClient( api_key=self._api_key, diff --git a/workos/session.py b/workos/session.py index f0ad6c45..0627430f 100644 --- a/workos/session.py +++ b/workos/session.py @@ -35,6 +35,7 @@ class SessionModule(Protocol): cookie_password: str jwks: PyJWKClient jwk_algorithms: List[str] + jwt_leeway: float def __init__( self, @@ -43,6 +44,7 @@ def __init__( client_id: str, session_data: str, cookie_password: str, + jwt_leeway: float = 0, ) -> None: # If the cookie password is not provided, throw an error if cookie_password is None or cookie_password == "": @@ -52,6 +54,7 @@ def __init__( self.client_id = client_id self.session_data = session_data self.cookie_password = cookie_password + self.jwt_leeway = jwt_leeway self.jwks = _get_jwks_client(self.user_management.get_jwks_url()) @@ -91,13 +94,13 @@ def authenticate( signing_key.key, algorithms=self.jwk_algorithms, options={"verify_aud": False}, + leeway=self.jwt_leeway, ) except jwt.exceptions.InvalidTokenError: return AuthenticateWithSessionCookieErrorResponse( authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT, ) - return AuthenticateWithSessionCookieSuccessResponse( authenticated=True, session_id=decoded["sid"], @@ -137,6 +140,20 @@ def get_logout_url(self, return_to: Optional[str] = None) -> str: ) return str(result) + def _is_valid_jwt(self, token: str) -> bool: + try: + signing_key = self.jwks.get_signing_key_from_jwt(token) + jwt.decode( + token, + signing_key.key, + algorithms=self.jwk_algorithms, + options={"verify_aud": False}, + leeway=self.jwt_leeway, + ) + return True + except jwt.exceptions.InvalidTokenError: + return False + @staticmethod def seal_data(data: Dict[str, Any], key: str) -> str: fernet = Fernet(key) @@ -163,6 +180,7 @@ def __init__( client_id: str, session_data: str, cookie_password: str, + jwt_leeway: float = 0, ) -> None: # If the cookie password is not provided, throw an error if cookie_password is None or cookie_password == "": @@ -172,6 +190,7 @@ def __init__( self.client_id = client_id self.session_data = session_data self.cookie_password = cookie_password + self.jwt_leeway = jwt_leeway self.jwks = _get_jwks_client(self.user_management.get_jwks_url()) @@ -224,6 +243,7 @@ def refresh( signing_key.key, algorithms=self.jwk_algorithms, options={"verify_aud": False}, + leeway=self.jwt_leeway, ) return RefreshWithSessionCookieSuccessResponse( @@ -255,6 +275,7 @@ def __init__( client_id: str, session_data: str, cookie_password: str, + jwt_leeway: float = 0, ) -> None: # If the cookie password is not provided, throw an error if cookie_password is None or cookie_password == "": @@ -264,6 +285,7 @@ def __init__( self.client_id = client_id self.session_data = session_data self.cookie_password = cookie_password + self.jwt_leeway = jwt_leeway self.jwks = _get_jwks_client(self.user_management.get_jwks_url()) @@ -316,6 +338,7 @@ async def refresh( signing_key.key, algorithms=self.jwk_algorithms, options={"verify_aud": False}, + leeway=self.jwt_leeway, ) return RefreshWithSessionCookieSuccessResponse( diff --git a/workos/user_management.py b/workos/user_management.py index 4a093b78..16dbbc85 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -891,6 +891,7 @@ def load_sealed_session( client_id=self._http_client.client_id, session_data=sealed_session, cookie_password=cookie_password, + jwt_leeway=self._client_configuration.jwt_leeway, ) def get_user(self, user_id: str) -> User: @@ -1531,6 +1532,7 @@ async def load_sealed_session( client_id=self._http_client.client_id, session_data=sealed_session, cookie_password=cookie_password, + jwt_leeway=self._client_configuration.jwt_leeway, ) async def get_user(self, user_id: str) -> User: