|
1 | 1 | import logging |
2 | 2 | from typing import Callable |
| 3 | +from typing import List |
3 | 4 | from typing import Optional |
4 | 5 |
|
5 | 6 | from cryptojwt import KeyJar |
|
8 | 9 |
|
9 | 10 | from idpyoidc.client.util import get_uri |
10 | 11 | from idpyoidc.impexp import ImpExp |
| 12 | +from idpyoidc.message import Message |
| 13 | +from idpyoidc.transform import preferred_to_registered |
11 | 14 | from idpyoidc.util import add_path |
12 | 15 | from idpyoidc.util import qualified_name |
13 | 16 |
|
14 | 17 | logger = logging.getLogger(__name__) |
15 | 18 |
|
| 19 | + |
16 | 20 | def claims_dump(info, exclude_attributes): |
17 | 21 | return {qualified_name(info.__class__): info.dump(exclude_attributes=exclude_attributes)} |
18 | 22 |
|
@@ -124,7 +128,7 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""): |
124 | 128 |
|
125 | 129 | return keyjar, _uri_path |
126 | 130 |
|
127 | | - def get_base_url(self, configuration: dict, entity_id: Optional[str]=""): |
| 131 | + def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""): |
128 | 132 | raise NotImplementedError() |
129 | 133 |
|
130 | 134 | def get_id(self, configuration: dict): |
@@ -183,6 +187,10 @@ def load_conf( |
183 | 187 | elif val: |
184 | 188 | self.set_preference(key, val) |
185 | 189 |
|
| 190 | + for attr,val in supports.items(): |
| 191 | + if attr not in self.prefer and val is not None: |
| 192 | + self.set_preference(attr,val) |
| 193 | + |
186 | 194 | self.verify_rules(supports) |
187 | 195 | return keyjar |
188 | 196 |
|
@@ -222,3 +230,53 @@ def get_claim(self, key, default=None): |
222 | 230 | return default |
223 | 231 | else: |
224 | 232 | return _val |
| 233 | + |
| 234 | + def get_endpoint_claims(self, endpoints): |
| 235 | + _info = {} |
| 236 | + for endp in endpoints: |
| 237 | + if endp.endpoint_name: |
| 238 | + _info[endp.endpoint_name] = endp.full_path |
| 239 | + for arg, claim in [("client_authn_method", "auth_methods"), |
| 240 | + ("auth_signing_alg_values", "auth_signing_alg_values")]: |
| 241 | + _val = getattr(endp, arg, None) |
| 242 | + if _val: |
| 243 | + # trust_mark_status_endpoint_auth_methods_supported |
| 244 | + md_param = f"{endp.endpoint_name}_{claim}" |
| 245 | + _info[md_param] = _val |
| 246 | + return _info |
| 247 | + |
| 248 | + def get_metadata(self, |
| 249 | + entity_type: Optional[str] = "", |
| 250 | + endpoints: Optional[list] = None, |
| 251 | + metadata_schema: Optional[Message] = None, |
| 252 | + extra_claims: Optional[List[str]] = None, |
| 253 | + supported: Optional[dict] = None, |
| 254 | + **kwargs): |
| 255 | + |
| 256 | + if supported is None: |
| 257 | + supported = self.supports() |
| 258 | + |
| 259 | + if not self.use: |
| 260 | + self.use = preferred_to_registered(self.prefer, supported=supported) |
| 261 | + |
| 262 | + metadata = self.use |
| 263 | + # the claims that can appear in the metadata |
| 264 | + if metadata_schema: |
| 265 | + attr = list(metadata_schema.c_param.keys()) |
| 266 | + else: |
| 267 | + attr = [] |
| 268 | + |
| 269 | + if extra_claims: |
| 270 | + attr.extend(extra_claims) |
| 271 | + |
| 272 | + if attr: |
| 273 | + metadata = {k: v for k, v in metadata.items() if k in attr} |
| 274 | + |
| 275 | + # collect endpoints |
| 276 | + if endpoints: |
| 277 | + metadata.update(self.get_endpoint_claims(endpoints)) |
| 278 | + |
| 279 | + if entity_type: |
| 280 | + return {entity_type: metadata} |
| 281 | + else: |
| 282 | + return metadata |
0 commit comments