diff --git a/fastapi_jwt/jwt.py b/fastapi_jwt/jwt.py index d9041ed..6fb430f 100644 --- a/fastapi_jwt/jwt.py +++ b/fastapi_jwt/jwt.py @@ -37,14 +37,9 @@ def utcnow(): ] -class JwtAuthorizationCredentials: - def __init__(self, subject: Dict[str, Any], jti: Optional[str] = None): - self.subject = subject - self.jti = jti - - def __getitem__(self, item: str) -> Any: - return self.subject[item] - +class JwtAuthorizationCredentials(dict): + def __init__(self, payload: Dict[str, Any]): + super().__init__(payload) class JwtAuthBase(ABC): class JwtAccessCookie(APIKeyCookie): @@ -75,6 +70,8 @@ def __init__( algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, + decode_kwargs: Optional[dict] = None, + subject_key: Optional[str] = "sub" ): assert jwt is not None, "python-jose must be installed to use JwtAuth" if places: @@ -93,6 +90,8 @@ def __init__( self.algorithm = algorithm self.access_expires_delta = access_expires_delta or timedelta(minutes=15) self.refresh_expires_delta = refresh_expires_delta or timedelta(days=31) + self.decode_kwargs = decode_kwargs + self.subject_key = subject_key @classmethod def from_other( @@ -103,6 +102,8 @@ def from_other( algorithm: Optional[str] = None, access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, + decode_kwargs: Optional[dict] = None, + subject_key: Optional[str] = "sub" ) -> 'JwtAuthBase': return cls( secret_key=secret_key or other.secret_key, @@ -110,6 +111,8 @@ def from_other( algorithm=algorithm or other.algorithm, access_expires_delta=access_expires_delta or other.access_expires_delta, refresh_expires_delta=refresh_expires_delta or other.refresh_expires_delta, + decode_kwargs=decode_kwargs or other.decode_kwargs, + subject_key=subject_key or other.subject_key ) def _decode(self, token: str) -> Optional[Dict[str, Any]]: @@ -119,6 +122,7 @@ def _decode(self, token: str) -> Optional[Dict[str, Any]]: self.secret_key, algorithms=[self.algorithm], options={"leeway": 10}, + **self.decode_kwargs ) return payload except jwt.ExpiredSignatureError as e: # type: ignore[attr-defined] @@ -146,7 +150,7 @@ def _generate_payload( now = utcnow() return { - "subject": subject.copy(), # main subject + self.subject_key: subject.copy(), # main subject "type": token_type, # 'access' or 'refresh' token "exp": now + expires_delta, # expire time "iat": now, # creation time @@ -264,6 +268,7 @@ def __init__( algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, + subject_key: Optional[str] = "sub" ): super().__init__( secret_key, @@ -283,7 +288,7 @@ async def _get_credentials( if payload: return JwtAuthorizationCredentials( - payload["subject"], payload.get("jti", None) + payload ) return None @@ -296,6 +301,7 @@ def __init__( algorithm: str = jwt.ALGORITHMS.HS256, # type: ignore[attr-defined] access_expires_delta: Optional[timedelta] = None, refresh_expires_delta: Optional[timedelta] = None, + decode_kwargs: Optional[dict] = None ): super().__init__( secret_key=secret_key, @@ -303,8 +309,9 @@ def __init__( auto_error=auto_error, algorithm=algorithm, access_expires_delta=access_expires_delta, - refresh_expires_delta=refresh_expires_delta, + refresh_expires_delta=refresh_expires_delta ) + self.decode_kwargs = decode_kwargs async def __call__( self, bearer: JwtAuthBase.JwtAccessBearer = Security(JwtAccess._bearer) @@ -405,7 +412,7 @@ async def _get_credentials( return None return JwtAuthorizationCredentials( - payload["subject"], payload.get("jti", None) + payload[self.subject_key], payload.get("jti", None) )