Skip to content

Commit edeab0f

Browse files
committed
refactor(auth, iam): support group federation
1 parent 2b6c4f9 commit edeab0f

File tree

2 files changed

+140
-14
lines changed

2 files changed

+140
-14
lines changed

redshift_connector/iam_helper.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,59 @@
2626

2727

2828
class IamHelper(IdpAuthHelper):
29+
class IAMAuthenticationType(enum.Enum):
30+
"""
31+
Defines authentication types supported by redshift-connector
32+
"""
33+
34+
NONE = enum.auto()
35+
PROFILE = enum.auto()
36+
IAM_KEYS_WITH_SESSION = enum.auto()
37+
IAM_KEYS = enum.auto()
38+
PLUGIN = enum.auto()
39+
2940
class GetClusterCredentialsAPIType(enum.Enum):
41+
"""
42+
Defines supported Python SDK methods used for Redshift credential retrieval
43+
"""
44+
3045
SERVERLESS_V1 = "get_credentials()"
3146
IAM_V1 = "get_cluster_credentials()"
3247
IAM_V2 = "get_cluster_credentials_with_iam()"
3348

3449
@staticmethod
35-
def can_support_v2(info: RedshiftProperty):
36-
return Version(pkg_resources.get_distribution("boto3").version) >= Version("1.24.5")
50+
def can_support_v2(provider_type: "IamHelper.IAMAuthenticationType") -> bool:
51+
"""
52+
Determines if user provided connection options and boto3 version support group federation.
53+
"""
54+
return (
55+
provider_type
56+
in (
57+
IamHelper.IAMAuthenticationType.PROFILE,
58+
IamHelper.IAMAuthenticationType.IAM_KEYS,
59+
IamHelper.IAMAuthenticationType.IAM_KEYS_WITH_SESSION,
60+
)
61+
) and Version(pkg_resources.get_distribution("boto3").version) >= Version("1.24.5")
3762

3863
credentials_cache: typing.Dict[str, dict] = {}
3964

4065
@staticmethod
41-
def get_cluster_credentials_api_type(info: RedshiftProperty):
66+
def get_cluster_credentials_api_type(
67+
info: RedshiftProperty, provider_type: "IamHelper.IAMAuthenticationType"
68+
) -> GetClusterCredentialsAPIType:
69+
"""
70+
Returns an enum representing the Python SDK method to use for getting temporary IAM credentials.
71+
"""
4272
if not info._is_serverless:
4373
if not info.group_federation:
4474
return IamHelper.GetClusterCredentialsAPIType.IAM_V1
45-
elif (not info.credentials_provider) and IamHelper.GetClusterCredentialsAPIType.can_support_v2(info):
75+
elif IamHelper.GetClusterCredentialsAPIType.can_support_v2(provider_type):
4676
return IamHelper.GetClusterCredentialsAPIType.IAM_V2
4777
else:
4878
raise InterfaceError("Authentication with plugin is not supported for group federation")
4979
elif not info.group_federation:
5080
return IamHelper.GetClusterCredentialsAPIType.SERVERLESS_V1
51-
elif (not info.credentials_provider) and IamHelper.GetClusterCredentialsAPIType.can_support_v2(info):
81+
elif IamHelper.GetClusterCredentialsAPIType.can_support_v2(provider_type):
5282
return IamHelper.GetClusterCredentialsAPIType.IAM_V2
5383
else:
5484
raise InterfaceError("Authentication with plugin is not supported for group federation")
@@ -59,6 +89,7 @@ def set_iam_properties(info: RedshiftProperty) -> RedshiftProperty:
5989
Helper function to handle connection properties and ensure required parameters are specified.
6090
Parameters
6191
"""
92+
provider_type: IamHelper.IAMAuthenticationType = IamHelper.IAMAuthenticationType.NONE
6293
# set properties present for both IAM, Native authentication
6394
IamHelper.set_auth_properties(info)
6495

@@ -186,6 +217,26 @@ def get_credentials_cache_key(info: RedshiftProperty, cred_provider: typing.Unio
186217
)
187218
)
188219

220+
@staticmethod
221+
def get_authentication_type(
222+
provider: typing.Union[IPlugin, AWSCredentialsProvider]
223+
) -> "IamHelper.IAMAuthenticationType":
224+
"""
225+
Returns an enum representing the type of authentication the user is requesting based on connection parameters.
226+
"""
227+
provider_type: IamHelper.IAMAuthenticationType = IamHelper.IAMAuthenticationType.NONE
228+
if isinstance(provider, IPlugin):
229+
provider_type = IamHelper.IAMAuthenticationType.PLUGIN
230+
elif isinstance(provider, AWSCredentialsProvider):
231+
if provider.profile is not None:
232+
provider_type = IamHelper.IAMAuthenticationType.PROFILE
233+
elif provider.session_token is not None:
234+
provider_type = IamHelper.IAMAuthenticationType.IAM_KEYS_WITH_SESSION
235+
else:
236+
provider_type = IamHelper.IAMAuthenticationType.IAM_KEYS
237+
238+
return provider_type
239+
189240
@staticmethod
190241
def set_cluster_credentials(
191242
cred_provider: typing.Union[IPlugin, AWSCredentialsProvider], info: RedshiftProperty
@@ -250,8 +301,9 @@ def set_cluster_credentials(
250301
# retries will occur by default ref:
251302
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#legacy-retry-mode
252303
_logger.debug("Credentials expired or not found...requesting from boto")
304+
provider_type: IamHelper.IAMAuthenticationType = IamHelper.get_authentication_type(cred_provider)
253305
get_creds_api_version: IamHelper.GetClusterCredentialsAPIType = (
254-
IamHelper.get_cluster_credentials_api_type(info)
306+
IamHelper.get_cluster_credentials_api_type(info, provider_type)
255307
)
256308
_logger.debug("boto3 get_credentials api version: {} will be used".format(get_creds_api_version.value))
257309

test/unit/test_iam_helper.py

Lines changed: 82 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -670,33 +670,107 @@ def test_set_cluster_credentials_refreshes_stale_credentials(
670670
)
671671

672672

673+
def test_get_authentication_type_for_iam_with_profile():
674+
provider = AWSCredentialsProvider()
675+
provider.profile = "test"
676+
assert IamHelper.get_authentication_type(provider) == IamHelper.IAMAuthenticationType.PROFILE
677+
678+
679+
def test_get_authentication_type_for_iam_with_key_session():
680+
provider = AWSCredentialsProvider()
681+
provider.access_key_id = "test_key"
682+
provider.session_token = "test_token"
683+
provider.secret_access_key = "test_secret_key"
684+
assert IamHelper.get_authentication_type(provider) == IamHelper.IAMAuthenticationType.IAM_KEYS_WITH_SESSION
685+
686+
687+
def test_get_authentication_type_for_iam_with_key():
688+
provider = AWSCredentialsProvider()
689+
provider.access_key_id = "test_key"
690+
provider.secret_access_key = "test_secret_key"
691+
assert IamHelper.get_authentication_type(provider) == IamHelper.IAMAuthenticationType.IAM_KEYS
692+
693+
694+
def test_get_authentication_type_for_iam_with_plugin():
695+
provider = BrowserSamlCredentialsProvider()
696+
assert IamHelper.get_authentication_type(provider) == IamHelper.IAMAuthenticationType.PLUGIN
697+
698+
673699
@pytest.mark.parametrize(
674-
"conn_params, exp_result",
700+
"conn_params, provider, exp_result",
675701
(
676-
({"credentials_provider": "BrowserSamlCredentialsProvider"}, IamHelper.GetClusterCredentialsAPIType.IAM_V1),
677-
({"group_federation": True}, IamHelper.GetClusterCredentialsAPIType.IAM_V2),
678-
({"is_serverless": True}, IamHelper.GetClusterCredentialsAPIType.SERVERLESS_V1),
679-
({"is_serverless": True, "group_federation": True}, IamHelper.GetClusterCredentialsAPIType.IAM_V2),
702+
(
703+
{"credentials_provider": "BrowserSamlCredentialsProvider"},
704+
IamHelper.IAMAuthenticationType.PLUGIN,
705+
IamHelper.GetClusterCredentialsAPIType.IAM_V1,
706+
),
707+
(
708+
{"group_federation": True},
709+
IamHelper.IAMAuthenticationType.PROFILE,
710+
IamHelper.GetClusterCredentialsAPIType.IAM_V2,
711+
),
712+
(
713+
{"is_serverless": True},
714+
IamHelper.IAMAuthenticationType.PROFILE,
715+
IamHelper.GetClusterCredentialsAPIType.SERVERLESS_V1,
716+
),
717+
(
718+
{"is_serverless": True, "group_federation": True},
719+
IamHelper.IAMAuthenticationType.IAM_KEYS,
720+
IamHelper.GetClusterCredentialsAPIType.IAM_V2,
721+
),
722+
(
723+
{"group_federation": True},
724+
IamHelper.IAMAuthenticationType.IAM_KEYS,
725+
IamHelper.GetClusterCredentialsAPIType.IAM_V2,
726+
),
727+
(
728+
{"is_serverless": True},
729+
IamHelper.IAMAuthenticationType.IAM_KEYS,
730+
IamHelper.GetClusterCredentialsAPIType.SERVERLESS_V1,
731+
),
732+
(
733+
{"is_serverless": True, "group_federation": True},
734+
IamHelper.IAMAuthenticationType.IAM_KEYS_WITH_SESSION,
735+
IamHelper.GetClusterCredentialsAPIType.IAM_V2,
736+
),
737+
(
738+
{"group_federation": True},
739+
IamHelper.IAMAuthenticationType.IAM_KEYS_WITH_SESSION,
740+
IamHelper.GetClusterCredentialsAPIType.IAM_V2,
741+
),
742+
(
743+
{"is_serverless": True},
744+
IamHelper.IAMAuthenticationType.IAM_KEYS_WITH_SESSION,
745+
IamHelper.GetClusterCredentialsAPIType.SERVERLESS_V1,
746+
),
747+
(
748+
{"is_serverless": True, "group_federation": True},
749+
IamHelper.IAMAuthenticationType.PROFILE,
750+
IamHelper.GetClusterCredentialsAPIType.IAM_V2,
751+
),
680752
(
681753
{"group_federation": True, "credentials_provider": "BrowserSamlCredentialsProvider"},
754+
IamHelper.IAMAuthenticationType.PLUGIN,
682755
"Authentication with plugin is not supported for group federation",
683756
),
684757
(
685758
{"is_serverless": True, "group_federation": True, "credentials_provider": "BrowserSamlCredentialsProvider"},
759+
IamHelper.IAMAuthenticationType.PLUGIN,
686760
"Authentication with plugin is not supported for group federation",
687761
),
688762
),
689763
)
690-
def test_get_cluster_credentials_api_type_will_use_correct_api(conn_params, exp_result):
764+
def test_get_cluster_credentials_api_type_will_use_correct_api(conn_params, provider, exp_result):
691765
info = RedshiftProperty()
692766
for param in conn_params.items():
693767
info.put(param[0], param[1])
694768

695769
if isinstance(exp_result, IamHelper.GetClusterCredentialsAPIType):
696-
assert IamHelper.get_cluster_credentials_api_type(info) == exp_result
770+
assert IamHelper.get_cluster_credentials_api_type(info, provider) == exp_result
697771
else:
698772
with pytest.raises(InterfaceError, match=exp_result):
699-
IamHelper.get_cluster_credentials_api_type(info)
773+
IamHelper.get_cluster_credentials_api_type(info, provider)
700774

701775

702776
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)