1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17+ # pylint: disable=too-many-branches
18+
1719"""Credential providers."""
1820
1921from __future__ import annotations
2931from datetime import timedelta
3032from pathlib import Path
3133from typing import Callable , cast
32- from urllib .parse import urlencode , urlsplit
34+ from urllib .parse import urlencode , urlsplit , urlunsplit
3335from xml .etree import ElementTree as ET
3436
3537import certifi
4244
4345from urllib3 .util import Retry , parse_url
4446
45- from minio .helpers import sha256_hash
47+ from minio .helpers import sha256_hash , url_replace
4648from minio .signer import sign_v4_sts
4749from minio .time import from_iso8601utc , to_amz_date , utcnow
4850from minio .xml import find , findtext
@@ -381,6 +383,13 @@ def __init__(
381383 self ,
382384 custom_endpoint : str | None = None ,
383385 http_client : PoolManager | None = None ,
386+ auth_token : str | None = None ,
387+ relative_uri : str | None = None ,
388+ full_uri : str | None = None ,
389+ token_file : str | None = None ,
390+ role_arn : str | None = None ,
391+ role_session_name : str | None = None ,
392+ region : str | None = None ,
384393 ):
385394 self ._custom_endpoint = custom_endpoint
386395 self ._http_client = http_client or PoolManager (
@@ -390,22 +399,41 @@ def __init__(
390399 status_forcelist = [500 , 502 , 503 , 504 ],
391400 ),
392401 )
393- self ._token_file = os .environ .get ("AWS_WEB_IDENTITY_TOKEN_FILE" )
394- self ._aws_region = os .environ .get ("AWS_REGION" )
395- self ._role_arn = os .environ .get ("AWS_ROLE_ARN" )
396- self ._role_session_name = os .environ .get ("AWS_ROLE_SESSION_NAME" )
397- self ._relative_uri = os .environ .get (
398- "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" ,
402+ self ._token = (
403+ os .environ .get ("AWS_CONTAINER_AUTHORIZATION_TOKEN" ) or
404+ auth_token
405+ )
406+ self ._token_file = (
407+ os .environ .get ("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE" ) or
408+ auth_token
409+ )
410+ self ._identity_file = (
411+ os .environ .get ("AWS_WEB_IDENTITY_TOKEN_FILE" ) or token_file
412+ )
413+ self ._aws_region = os .environ .get ("AWS_REGION" ) or region
414+ self ._role_arn = os .environ .get ("AWS_ROLE_ARN" ) or role_arn
415+ self ._role_session_name = (
416+ os .environ .get ("AWS_ROLE_SESSION_NAME" ) or role_session_name
417+ )
418+ self ._relative_uri = (
419+ os .environ .get ("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI" ) or
420+ relative_uri
399421 )
400422 if self ._relative_uri and not self ._relative_uri .startswith ("/" ):
401423 self ._relative_uri = "/" + self ._relative_uri
402- self ._full_uri = os .environ .get ("AWS_CONTAINER_CREDENTIALS_FULL_URI" )
424+ self ._full_uri = (
425+ os .environ .get ("AWS_CONTAINER_CREDENTIALS_FULL_URI" ) or
426+ full_uri
427+ )
403428 self ._credentials : Credentials | None = None
404429
405- def fetch (self , url : str ) -> Credentials :
406- """Fetch credentials from EC2/ECS. """
407-
408- res = _urlopen (self ._http_client , "GET" , url )
430+ def fetch (
431+ self ,
432+ url : str ,
433+ headers : dict [str , str | list [str ] | tuple [str ]] | None = None ,
434+ ) -> Credentials :
435+ """Fetch credentials from EC2/ECS."""
436+ res = _urlopen (self ._http_client , "GET" , url , headers = headers )
409437 data = json .loads (res .data )
410438 if data .get ("Code" , "Success" ) != "Success" :
411439 raise ValueError (
@@ -428,14 +456,16 @@ def retrieve(self) -> Credentials:
428456 return self ._credentials
429457
430458 url = self ._custom_endpoint
431- if self ._token_file :
459+ if self ._identity_file :
432460 if not url :
433461 url = "https://sts.amazonaws.com"
434462 if self ._aws_region :
435463 url = f"https://sts.{ self ._aws_region } .amazonaws.com"
464+ if self ._aws_region .startswith ("cn-" ):
465+ url += ".cn"
436466
437467 provider = WebIdentityProvider (
438- lambda : _get_jwt_token (cast (str , self ._token_file )),
468+ lambda : _get_jwt_token (cast (str , self ._identity_file )),
439469 url ,
440470 role_arn = self ._role_arn ,
441471 role_session_name = self ._role_session_name ,
@@ -444,30 +474,55 @@ def retrieve(self) -> Credentials:
444474 self ._credentials = provider .retrieve ()
445475 return cast (Credentials , self ._credentials )
446476
477+ headers : dict [str , str | list [str ] | tuple [str ]] | None = None
447478 if self ._relative_uri :
448479 if not url :
449480 url = "http://169.254.170.2" + self ._relative_uri
481+ headers = {"Authorization" : self ._token } if self ._token else None
450482 elif self ._full_uri :
451- if not url :
483+ token = self ._token
484+ if self ._token_file :
452485 url = self ._full_uri
453- _check_loopback_host (url )
486+ with open (self ._token_file , encoding = "utf-8" ) as file :
487+ token = file .read ()
488+ else :
489+ if not url :
490+ url = self ._full_uri
491+ _check_loopback_host (url )
492+ headers = {"Authorization" : token } if token else None
454493 else :
455494 if not url :
456- url = (
457- "http://169.254.169.254" +
458- "/latest/meta-data/iam/security-credentials/"
459- )
460-
461- res = _urlopen (self ._http_client , "GET" , url )
495+ url = "http://169.254.169.254"
496+
497+ # Get IMDS Token
498+ res = _urlopen (
499+ self ._http_client ,
500+ "PUT" ,
501+ url + "/latest/api/token" ,
502+ headers = {"X-aws-ec2-metadata-token-ttl-seconds" : "21600" },
503+ )
504+ token = res .data .decode ("utf-8" )
505+ headers = {"X-aws-ec2-metadata-token" : token } if token else None
506+
507+ # Get role name
508+ res = _urlopen (
509+ self ._http_client ,
510+ "GET" ,
511+ urlunsplit (
512+ url_replace (
513+ urlsplit (url ),
514+ path = "/latest/meta-data/iam/security-credentials/" ,
515+ ),
516+ ),
517+ headers = headers ,
518+ )
462519 role_names = res .data .decode ("utf-8" ).split ("\n " )
463520 if not role_names :
464521 raise ValueError (f"no IAM roles attached to EC2 service { url } " )
465522 url += "/" + role_names [0 ].strip ("\r " )
466-
467523 if not url :
468524 raise ValueError ("url is empty; this should not happen" )
469-
470- self ._credentials = self .fetch (url )
525+ self ._credentials = self .fetch (url , headers = headers )
471526 return self ._credentials
472527
473528
0 commit comments