@@ -40,113 +40,53 @@ class OpenIDConnectFrontend(FrontendModule):
4040 """
4141
4242 def __init__ (self , auth_req_callback_func , internal_attributes , conf , base_url , name ):
43- self . _validate_config (conf )
43+ _validate_config (conf )
4444 super ().__init__ (auth_req_callback_func , internal_attributes , base_url , name )
4545
4646 self .config = conf
47- self .signing_key = RSAKey (key = rsa_load (conf ["signing_key_path" ]), use = "sig" , alg = "RS256" ,
48- kid = conf .get ("signing_key_id" , "" ))
49-
50- def _create_provider (self , endpoint_baseurl ):
51- response_types_supported = self .config ["provider" ].get ("response_types_supported" , ["id_token" ])
52- subject_types_supported = self .config ["provider" ].get ("subject_types_supported" , ["pairwise" ])
53- scopes_supported = self .config ["provider" ].get ("scopes_supported" , ["openid" ])
54- extra_scopes = self .config ["provider" ].get ("extra_scopes" )
55- capabilities = {
56- "issuer" : self .base_url ,
57- "authorization_endpoint" : "{}/{}" .format (endpoint_baseurl , AuthorizationEndpoint .url ),
58- "jwks_uri" : "{}/jwks" .format (endpoint_baseurl ),
59- "response_types_supported" : response_types_supported ,
60- "id_token_signing_alg_values_supported" : [self .signing_key .alg ],
61- "response_modes_supported" : ["fragment" , "query" ],
62- "subject_types_supported" : subject_types_supported ,
63- "claim_types_supported" : ["normal" ],
64- "claims_parameter_supported" : True ,
65- "claims_supported" : [attribute_map ["openid" ][0 ]
66- for attribute_map in self .internal_attributes ["attributes" ].values ()
67- if "openid" in attribute_map ],
68- "request_parameter_supported" : False ,
69- "request_uri_parameter_supported" : False ,
70- "scopes_supported" : scopes_supported
71- }
72-
73- if 'code' in response_types_supported :
74- capabilities ["token_endpoint" ] = "{}/{}" .format (endpoint_baseurl , TokenEndpoint .url )
75-
76- if self .config ["provider" ].get ("client_registration_supported" , False ):
77- capabilities ["registration_endpoint" ] = "{}/{}" .format (endpoint_baseurl , RegistrationEndpoint .url )
78-
79- authz_state = self ._init_authorization_state ()
47+ provider_config = self .config ["provider" ]
48+ provider_config ["issuer" ] = base_url
49+
50+ self .signing_key = RSAKey (
51+ key = rsa_load (self .config ["signing_key_path" ]),
52+ use = "sig" ,
53+ alg = "RS256" ,
54+ kid = self .config .get ("signing_key_id" , "" ),
55+ )
56+
8057 db_uri = self .config .get ("db_uri" )
58+ self .user_db = (
59+ StorageBase .from_uri (db_uri , db_name = "satosa" , collection = "authz_codes" )
60+ if db_uri
61+ else {}
62+ )
63+
64+ sub_hash_salt = self .config .get ("sub_hash_salt" , rndstr (16 ))
65+ authz_state = _init_authorization_state (provider_config , db_uri , sub_hash_salt )
66+
8167 client_db_uri = self .config .get ("client_db_uri" )
8268 cdb_file = self .config .get ("client_db_path" )
8369 if client_db_uri :
8470 cdb = StorageBase .from_uri (
85- client_db_uri , db_name = "satosa" , collection = "clients"
71+ client_db_uri , db_name = "satosa" , collection = "clients" , ttl = None
8672 )
8773 elif cdb_file :
8874 with open (cdb_file ) as f :
8975 cdb = json .loads (f .read ())
9076 else :
9177 cdb = {}
9278
93- #XXX What is the correct ttl for user_db? Is it the same as authz_code_db?
94- self .user_db = (
95- StorageBase .from_uri (db_uri , db_name = "satosa" , collection = "authz_codes" )
96- if db_uri
97- else {}
98- )
99-
100- self .provider = Provider (
79+ self .endpoint_baseurl = "{}/{}" .format (self .base_url , self .name )
80+ self .provider = _create_provider (
81+ provider_config ,
82+ self .endpoint_baseurl ,
83+ self .internal_attributes ,
10184 self .signing_key ,
102- capabilities ,
10385 authz_state ,
86+ self .user_db ,
10487 cdb ,
105- Userinfo (self .user_db ),
106- extra_scopes = extra_scopes ,
107- id_token_lifetime = self .config ["provider" ].get ("id_token_lifetime" , 3600 ),
10888 )
10989
110- def _init_authorization_state (self ):
111- sub_hash_salt = self .config .get ("sub_hash_salt" , rndstr (16 ))
112- db_uri = self .config .get ("db_uri" )
113- if db_uri :
114- authz_code_db = StorageBase .from_uri (
115- db_uri ,
116- db_name = "satosa" ,
117- collection = "authz_codes" ,
118- ttl = self .config ["provider" ].get ("authorization_code_lifetime" , 600 ),
119- )
120- access_token_db = StorageBase .from_uri (
121- db_uri ,
122- db_name = "satosa" ,
123- collection = "access_tokens" ,
124- ttl = self .config ["provider" ].get ("access_token_lifetime" , 3600 ),
125- )
126- refresh_token_db = StorageBase .from_uri (
127- db_uri ,
128- db_name = "satosa" ,
129- collection = "refresh_tokens" ,
130- ttl = self .config ["provider" ].get ("refresh_token_lifetime" , None ),
131- )
132- #XXX what is the correct TTL for sub_db?
133- sub_db = StorageBase .from_uri (
134- db_uri , db_name = "satosa" , collection = "subject_identifiers"
135- )
136- else :
137- authz_code_db = None
138- access_token_db = None
139- refresh_token_db = None
140- sub_db = None
141-
142- token_lifetimes = {k : self .config ["provider" ][k ] for k in ["authorization_code_lifetime" ,
143- "access_token_lifetime" ,
144- "refresh_token_lifetime" ,
145- "refresh_token_threshold" ]
146- if k in self .config ["provider" ]}
147- return AuthorizationState (HashBasedSubjectIdentifierFactory (sub_hash_salt ), authz_code_db , access_token_db ,
148- refresh_token_db , sub_db , ** token_lifetimes )
149-
15090 def _get_extra_id_token_claims (self , user_id , client_id ):
15191 if "extra_id_token_claims" in self .config ["provider" ]:
15292 config = self .config ["provider" ]["extra_id_token_claims" ].get (client_id , [])
@@ -223,9 +163,6 @@ def register_endpoints(self, backend_names):
223163 else :
224164 backend_name = backend_names [0 ]
225165
226- endpoint_baseurl = "{}/{}" .format (self .base_url , self .name )
227- self ._create_provider (endpoint_baseurl )
228-
229166 provider_config = ("^.well-known/openid-configuration$" , self .provider_config )
230167 jwks_uri = ("^{}/jwks$" .format (self .name ), self .jwks )
231168
@@ -236,42 +173,36 @@ def register_endpoints(self, backend_names):
236173 auth_path = urlparse (auth_endpoint ).path .lstrip ("/" )
237174 else :
238175 auth_path = "{}/{}" .format (self .name , AuthorizationEndpoint .url )
176+
239177 authentication = ("^{}$" .format (auth_path ), self .handle_authn_request )
240178 url_map = [provider_config , jwks_uri , authentication ]
241179
242180 if any ("code" in v for v in self .provider .configuration_information ["response_types_supported" ]):
243- self .provider .configuration_information ["token_endpoint" ] = "{}/{}" .format (endpoint_baseurl ,
244- TokenEndpoint .url )
245- token_endpoint = ("^{}/{}" .format (self .name , TokenEndpoint .url ), self .token_endpoint )
181+ self .provider .configuration_information ["token_endpoint" ] = "{}/{}" .format (
182+ self .endpoint_baseurl , TokenEndpoint .url
183+ )
184+ token_endpoint = (
185+ "^{}/{}" .format (self .name , TokenEndpoint .url ), self .token_endpoint
186+ )
246187 url_map .append (token_endpoint )
247188
248- self .provider .configuration_information ["userinfo_endpoint" ] = "{}/{}" .format (endpoint_baseurl ,
249- UserinfoEndpoint .url )
250- userinfo_endpoint = ("^{}/{}" .format (self .name , UserinfoEndpoint .url ), self .userinfo_endpoint )
189+ self .provider .configuration_information ["userinfo_endpoint" ] = (
190+ "{}/{}" .format (self .endpoint_baseurl , UserinfoEndpoint .url )
191+ )
192+ userinfo_endpoint = (
193+ "^{}/{}" .format (self .name , UserinfoEndpoint .url ), self .userinfo_endpoint
194+ )
251195 url_map .append (userinfo_endpoint )
196+
252197 if "registration_endpoint" in self .provider .configuration_information :
253- client_registration = ("^{}/{}" .format (self .name , RegistrationEndpoint .url ), self .client_registration )
198+ client_registration = (
199+ "^{}/{}" .format (self .name , RegistrationEndpoint .url ),
200+ self .client_registration ,
201+ )
254202 url_map .append (client_registration )
255203
256204 return url_map
257205
258- def _validate_config (self , config ):
259- """
260- Validates that all necessary config parameters are specified.
261- :type config: dict[str, dict[str, Any] | str]
262- :param config: the module config
263- """
264- if config is None :
265- raise ValueError ("OIDCFrontend conf can't be 'None'." )
266-
267- for k in {"signing_key_path" , "provider" }:
268- if k not in config :
269- raise ValueError ("Missing configuration parameter '{}' for OpenID Connect frontend." .format (k ))
270-
271- if "signing_key_id" in config and type (config ["signing_key_id" ]) is not str :
272- raise ValueError (
273- "The configuration parameter 'signing_key_id' is not defined as a string for OpenID Connect frontend." )
274-
275206 def _get_authn_request_from_state (self , state ):
276207 """
277208 Extract the clietns request stoed in the SATOSA state.
@@ -438,6 +369,128 @@ def userinfo_endpoint(self, context):
438369 return response
439370
440371
372+ def _validate_config (config ):
373+ """
374+ Validates that all necessary config parameters are specified.
375+ :type config: dict[str, dict[str, Any] | str]
376+ :param config: the module config
377+ """
378+ if config is None :
379+ raise ValueError ("OIDCFrontend configuration can't be 'None'." )
380+
381+ for k in {"signing_key_path" , "provider" }:
382+ if k not in config :
383+ raise ValueError ("Missing configuration parameter '{}' for OpenID Connect frontend." .format (k ))
384+
385+ if "signing_key_id" in config and type (config ["signing_key_id" ]) is not str :
386+ raise ValueError (
387+ "The configuration parameter 'signing_key_id' is not defined as a string for OpenID Connect frontend." )
388+
389+
390+ def _create_provider (
391+ provider_config ,
392+ endpoint_baseurl ,
393+ internal_attributes ,
394+ signing_key ,
395+ authz_state ,
396+ user_db ,
397+ cdb ,
398+ ):
399+ response_types_supported = provider_config .get ("response_types_supported" , ["id_token" ])
400+ subject_types_supported = provider_config .get ("subject_types_supported" , ["pairwise" ])
401+ scopes_supported = provider_config .get ("scopes_supported" , ["openid" ])
402+ extra_scopes = provider_config .get ("extra_scopes" )
403+ capabilities = {
404+ "issuer" : provider_config ["issuer" ],
405+ "authorization_endpoint" : "{}/{}" .format (endpoint_baseurl , AuthorizationEndpoint .url ),
406+ "jwks_uri" : "{}/jwks" .format (endpoint_baseurl ),
407+ "response_types_supported" : response_types_supported ,
408+ "id_token_signing_alg_values_supported" : [signing_key .alg ],
409+ "response_modes_supported" : ["fragment" , "query" ],
410+ "subject_types_supported" : subject_types_supported ,
411+ "claim_types_supported" : ["normal" ],
412+ "claims_parameter_supported" : True ,
413+ "claims_supported" : [
414+ attribute_map ["openid" ][0 ]
415+ for attribute_map in internal_attributes ["attributes" ].values ()
416+ if "openid" in attribute_map
417+ ],
418+ "request_parameter_supported" : False ,
419+ "request_uri_parameter_supported" : False ,
420+ "scopes_supported" : scopes_supported
421+ }
422+
423+ if 'code' in response_types_supported :
424+ capabilities ["token_endpoint" ] = "{}/{}" .format (
425+ endpoint_baseurl , TokenEndpoint .url
426+ )
427+
428+ if provider_config .get ("client_registration_supported" , False ):
429+ capabilities ["registration_endpoint" ] = "{}/{}" .format (
430+ endpoint_baseurl , RegistrationEndpoint .url
431+ )
432+
433+ provider = Provider (
434+ signing_key ,
435+ capabilities ,
436+ authz_state ,
437+ cdb ,
438+ Userinfo (user_db ),
439+ extra_scopes = extra_scopes ,
440+ id_token_lifetime = provider_config .get ("id_token_lifetime" , 3600 ),
441+ )
442+ return provider
443+
444+
445+ def _init_authorization_state (provider_config , db_uri , sub_hash_salt ):
446+ if db_uri :
447+ authz_code_db = StorageBase .from_uri (
448+ db_uri ,
449+ db_name = "satosa" ,
450+ collection = "authz_codes" ,
451+ ttl = provider_config .get ("authorization_code_lifetime" , 600 ),
452+ )
453+ access_token_db = StorageBase .from_uri (
454+ db_uri ,
455+ db_name = "satosa" ,
456+ collection = "access_tokens" ,
457+ ttl = provider_config .get ("access_token_lifetime" , 3600 ),
458+ )
459+ refresh_token_db = StorageBase .from_uri (
460+ db_uri ,
461+ db_name = "satosa" ,
462+ collection = "refresh_tokens" ,
463+ ttl = provider_config .get ("refresh_token_lifetime" , None ),
464+ )
465+ sub_db = StorageBase .from_uri (
466+ db_uri , db_name = "satosa" , collection = "subject_identifiers" , ttl = None
467+ )
468+ else :
469+ authz_code_db = None
470+ access_token_db = None
471+ refresh_token_db = None
472+ sub_db = None
473+
474+ token_lifetimes = {
475+ k : provider_config [k ]
476+ for k in [
477+ "authorization_code_lifetime" ,
478+ "access_token_lifetime" ,
479+ "refresh_token_lifetime" ,
480+ "refresh_token_threshold" ,
481+ ]
482+ if k in provider_config
483+ }
484+ return AuthorizationState (
485+ HashBasedSubjectIdentifierFactory (sub_hash_salt ),
486+ authz_code_db ,
487+ access_token_db ,
488+ refresh_token_db ,
489+ sub_db ,
490+ ** token_lifetimes ,
491+ )
492+
493+
441494def combine_return_input (values ):
442495 return values
443496
0 commit comments