Skip to content

Commit 2b6c4f9

Browse files
committed
feat(auth, iam): support group federation
1 parent 4fc0af9 commit 2b6c4f9

File tree

8 files changed

+117
-1
lines changed

8 files changed

+117
-1
lines changed

redshift_connector/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def connect(
183183
is_serverless: typing.Optional[bool] = False,
184184
serverless_acct_id: typing.Optional[str] = None,
185185
serverless_work_group: typing.Optional[str] = None,
186+
group_federation: typing.Optional[bool] = None,
186187
) -> Connection:
187188
"""
188189
Establishes a :class:`Connection` to an Amazon Redshift cluster. This function validates user input, optionally authenticates using an identity provider plugin, then constructs a :class:`Connection` object.
@@ -284,6 +285,8 @@ def connect(
284285
The account ID of the serverless. Default value None
285286
serverless_work_group: Optional[str]
286287
The name of work group for serverless end point. Default value None.
288+
group_federation: Optional[bool]
289+
Use the IDP Groups in the Redshift. Default value False.
287290
Returns
288291
-------
289292
A Connection object associated with the specified Amazon Redshift cluster: :class:`Connection`
@@ -307,6 +310,7 @@ def connect(
307310
info.put("db_user", db_user)
308311
info.put("endpoint_url", endpoint_url)
309312
info.put("force_lowercase", force_lowercase)
313+
info.put("group_federation", group_federation)
310314
info.put("host", host)
311315
info.put("iam", iam)
312316
info.put("iam_disable_cache", iam_disable_cache)

redshift_connector/iam_helper.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import enum
23
import logging
34
import typing
45

@@ -25,9 +26,33 @@
2526

2627

2728
class IamHelper(IdpAuthHelper):
29+
class GetClusterCredentialsAPIType(enum.Enum):
30+
SERVERLESS_V1 = "get_credentials()"
31+
IAM_V1 = "get_cluster_credentials()"
32+
IAM_V2 = "get_cluster_credentials_with_iam()"
33+
34+
@staticmethod
35+
def can_support_v2(info: RedshiftProperty):
36+
return Version(pkg_resources.get_distribution("boto3").version) >= Version("1.24.5")
2837

2938
credentials_cache: typing.Dict[str, dict] = {}
3039

40+
@staticmethod
41+
def get_cluster_credentials_api_type(info: RedshiftProperty):
42+
if not info._is_serverless:
43+
if not info.group_federation:
44+
return IamHelper.GetClusterCredentialsAPIType.IAM_V1
45+
elif (not info.credentials_provider) and IamHelper.GetClusterCredentialsAPIType.can_support_v2(info):
46+
return IamHelper.GetClusterCredentialsAPIType.IAM_V2
47+
else:
48+
raise InterfaceError("Authentication with plugin is not supported for group federation")
49+
elif not info.group_federation:
50+
return IamHelper.GetClusterCredentialsAPIType.SERVERLESS_V1
51+
elif (not info.credentials_provider) and IamHelper.GetClusterCredentialsAPIType.can_support_v2(info):
52+
return IamHelper.GetClusterCredentialsAPIType.IAM_V2
53+
else:
54+
raise InterfaceError("Authentication with plugin is not supported for group federation")
55+
3156
@staticmethod
3257
def set_iam_properties(info: RedshiftProperty) -> RedshiftProperty:
3358
"""
@@ -225,7 +250,12 @@ def set_cluster_credentials(
225250
# retries will occur by default ref:
226251
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#legacy-retry-mode
227252
_logger.debug("Credentials expired or not found...requesting from boto")
228-
if info._is_serverless:
253+
get_creds_api_version: IamHelper.GetClusterCredentialsAPIType = (
254+
IamHelper.get_cluster_credentials_api_type(info)
255+
)
256+
_logger.debug("boto3 get_credentials api version: {} will be used".format(get_creds_api_version.value))
257+
258+
if get_creds_api_version == IamHelper.GetClusterCredentialsAPIType.SERVERLESS_V1:
229259
get_cred_args: typing.Dict[str, str] = {"dbName": info.db_name}
230260
if info.serverless_work_group:
231261
get_cred_args["workgroupName"] = info.serverless_work_group
@@ -237,6 +267,15 @@ def set_cluster_credentials(
237267
# re-map expiration for compatibility with redshift credential response
238268
cred["Expiration"] = cred["expiration"]
239269
del cred["expiration"]
270+
elif get_creds_api_version == IamHelper.GetClusterCredentialsAPIType.IAM_V2:
271+
cred = typing.cast(
272+
typing.Dict[str, typing.Union[str, datetime.datetime]],
273+
client.get_cluster_credentials_with_iam(
274+
DbName=info.db_name,
275+
ClusterIdentifier=info.cluster_identifier,
276+
DurationSeconds=info.duration,
277+
),
278+
)
240279
else:
241280
cred = typing.cast(
242281
typing.Dict[str, typing.Union[str, datetime.datetime]],

redshift_connector/plugin/i_plugin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,7 @@ def refresh(self: "IPlugin") -> None:
4242
Refreshes the credentials, stored in :class:NativeTokenHolder, for the current plugin.
4343
"""
4444
pass # pragma: no cover
45+
46+
@abstractmethod
47+
def set_group_federation(self: "IPlugin", group_federation: bool):
48+
pass

redshift_connector/plugin/jwt_credentials_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,14 @@ def add_parameter(
4141
self.provider_name = info.provider_name
4242
self.ssl_insecure = info.ssl_insecure
4343
self.disable_cache = info.iam_disable_cache
44+
self.group_federation = False
4445

4546
if info.role_session_name is not None:
4647
self.role_session_name = info.role_session_name
4748

49+
def set_group_federation(self: "JwtCredentialsProvider", group_federation: bool):
50+
self.group_federation = group_federation
51+
4852
def get_credentials(self: "JwtCredentialsProvider") -> NativeTokenHolder:
4953
credentials: typing.Optional[NativeTokenHolder] = None
5054

redshift_connector/plugin/saml_credentials_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(self: "SamlCredentialsProvider") -> None:
3535
self.auto_create: typing.Optional[bool] = None
3636
self.region: typing.Optional[str] = None
3737
self.principal: typing.Optional[str] = None
38+
self.group_federation: bool = False
3839

3940
self.cache: dict = {}
4041

@@ -53,6 +54,9 @@ def add_parameter(self: "SamlCredentialsProvider", info: RedshiftProperty) -> No
5354
self.region = info.region
5455
self.principal = info.principal
5556

57+
def set_group_federation(self: "SamlCredentialsProvider", group_federation: bool):
58+
self.group_federation = group_federation
59+
5660
def get_sub_type(self) -> int:
5761
return IdpAuthHelper.SAML_PLUGIN
5862

redshift_connector/redshift_property.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def __init__(self: "RedshiftProperty", **kwargs):
117117
self.is_serverless: bool = False
118118
self.serverless_acct_id: typing.Optional[str] = None
119119
self.serverless_work_group: typing.Optional[str] = None
120+
self.group_federation: bool = False
120121

121122
else:
122123
for k, v in kwargs.items():

test/manual/auth/test_aws_credentials.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,34 @@ def test_use_aws_credentials_default_profile():
3131
) as con:
3232
with con.cursor() as cursor:
3333
cursor.execute("select 1")
34+
35+
36+
"""
37+
How to use:
38+
0) Generate credentials using instructions: https://docs.aws.amazon.com/sdk-for-javascript/v2/developer-guide/getting-your-credentials.html
39+
1) In the connect method below, specify the connection parameters
40+
3) Specify the AWS IAM credentials in the variables above
41+
4) Update iam_helper.py to include correct min version. line `Version(pkg_resources.get_distribution("boto3").version) > Version("9.99.9999"):`
42+
5) Manually execute this test
43+
"""
44+
45+
46+
@pytest.mark.skip(reason="manual")
47+
def test_use_get_cluster_credentials_with_iam(db_kwargs):
48+
role_name = "groupFederationTest"
49+
with redshift_connector.connect(**db_kwargs) as conn:
50+
with conn.cursor() as cursor:
51+
# https://docs.aws.amazon.com/redshift/latest/dg/r_CREATE_USER.html
52+
cursor.execute('create user "IAMR:{}" with password disable;'.format(role_name))
53+
with redshift_connector.connect(
54+
iam=True,
55+
database="replace_me",
56+
cluster_identifier="replace_me",
57+
region="replace_me",
58+
profile="replace_me", # contains credentials for AssumeRole groupFederationTest
59+
group_federation=True,
60+
) as con:
61+
with con.cursor() as cursor:
62+
cursor.execute("select 1")
63+
cursor.execute("select current_user")
64+
assert cursor.fetchone()[0] == role_name

test/unit/test_iam_helper.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,35 @@ def test_set_cluster_credentials_refreshes_stale_credentials(
670670
)
671671

672672

673+
@pytest.mark.parametrize(
674+
"conn_params, exp_result",
675+
(
676+
({"credentials_provider": "BrowserSamlCredentialsProvider"}, IamHelper.GetClusterCredentialsAPIType.IAM_V1),
677+
({"group_federation": True}, IamHelper.GetClusterCredentialsAPIType.IAM_V2),
678+
({"is_serverless": True}, IamHelper.GetClusterCredentialsAPIType.SERVERLESS_V1),
679+
({"is_serverless": True, "group_federation": True}, IamHelper.GetClusterCredentialsAPIType.IAM_V2),
680+
(
681+
{"group_federation": True, "credentials_provider": "BrowserSamlCredentialsProvider"},
682+
"Authentication with plugin is not supported for group federation",
683+
),
684+
(
685+
{"is_serverless": True, "group_federation": True, "credentials_provider": "BrowserSamlCredentialsProvider"},
686+
"Authentication with plugin is not supported for group federation",
687+
),
688+
),
689+
)
690+
def test_get_cluster_credentials_api_type_will_use_correct_api(conn_params, exp_result):
691+
info = RedshiftProperty()
692+
for param in conn_params.items():
693+
info.put(param[0], param[1])
694+
695+
if isinstance(exp_result, IamHelper.GetClusterCredentialsAPIType):
696+
assert IamHelper.get_cluster_credentials_api_type(info) == exp_result
697+
else:
698+
with pytest.raises(InterfaceError, match=exp_result):
699+
IamHelper.get_cluster_credentials_api_type(info)
700+
701+
673702
@pytest.mark.parametrize(
674703
"boto3_version",
675704
(

0 commit comments

Comments
 (0)