|
| 1 | +"""Implements the service that can exchange one token for another.""" |
| 2 | +import logging |
| 3 | + |
| 4 | +from idpyoidc.client.oauth2.utils import get_state_parameter |
| 5 | +from idpyoidc.client.service import Service |
| 6 | +from idpyoidc.exception import MissingParameter |
| 7 | +from idpyoidc.exception import MissingRequiredAttribute |
| 8 | +from idpyoidc.message import oauth2 |
| 9 | +from idpyoidc.message.oauth2 import ResponseMessage |
| 10 | +from idpyoidc.time_util import time_sans_frac |
| 11 | + |
| 12 | +LOGGER = logging.getLogger(__name__) |
| 13 | + |
| 14 | + |
| 15 | +class TokenExchange(Service): |
| 16 | + """The token exchange service.""" |
| 17 | + |
| 18 | + msg_type = oauth2.TokenExchangeRequest |
| 19 | + response_cls = oauth2.TokenExchangeResponse |
| 20 | + error_msg = ResponseMessage |
| 21 | + endpoint_name = "token_endpoint" |
| 22 | + synchronous = True |
| 23 | + service_name = "token_exchange" |
| 24 | + default_authn_method = "client_secret_basic" |
| 25 | + http_method = "POST" |
| 26 | + request_body_type = "urlencoded" |
| 27 | + response_body_type = "json" |
| 28 | + |
| 29 | + |
| 30 | + def __init__(self, client_get, conf=None): |
| 31 | + Service.__init__(self, client_get, conf=conf) |
| 32 | + self.pre_construct.append(self.oauth_pre_construct) |
| 33 | + |
| 34 | + def update_service_context(self, resp, key="", **kwargs): |
| 35 | + if "expires_in" in resp: |
| 36 | + resp["__expires_at"] = time_sans_frac() + int(resp["expires_in"]) |
| 37 | + self.client_get("service_context").state.store_item(resp, "token_response", key) |
| 38 | + |
| 39 | + def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): |
| 40 | + """ |
| 41 | +
|
| 42 | + :param request_args: Initial set of request arguments |
| 43 | + :param kwargs: Extra keyword arguments |
| 44 | + :return: Request arguments |
| 45 | + """ |
| 46 | + if request_args is None: |
| 47 | + request_args = {} |
| 48 | + |
| 49 | + if 'subject_token' not in request_args: |
| 50 | + try: |
| 51 | + _key = get_state_parameter(request_args, kwargs) |
| 52 | + except MissingParameter: |
| 53 | + raise MissingRequiredAttribute("subject_token") |
| 54 | + |
| 55 | + parameters = {'access_token', 'scope'} |
| 56 | + |
| 57 | + _state = self.client_get("service_context").state |
| 58 | + |
| 59 | + _args = _state.extend_request_args( |
| 60 | + {}, oauth2.AuthorizationResponse, "auth_response", _key, parameters |
| 61 | + ) |
| 62 | + _args = _state.extend_request_args( |
| 63 | + _args, oauth2.AccessTokenResponse, "token_response", _key, parameters |
| 64 | + ) |
| 65 | + _args = _state.extend_request_args( |
| 66 | + _args, oauth2.AccessTokenResponse, "refresh_token_response", _key, parameters |
| 67 | + ) |
| 68 | + |
| 69 | + request_args["subject_token"] = _args["access_token"] |
| 70 | + request_args["subject_token_type"] = 'urn:ietf:params:oauth:token-type:access_token' |
| 71 | + if 'scope' not in request_args and "scope" in _args: |
| 72 | + request_args["scope"] = _args["scope"] |
| 73 | + |
| 74 | + return request_args, post_args |
0 commit comments