Skip to content

Commit 4e44b9d

Browse files
committed
feat(serverless): support nlb connection
1 parent 43636bd commit 4e44b9d

File tree

7 files changed

+111
-41
lines changed

7 files changed

+111
-41
lines changed

redshift_connector/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ def connect(
180180
provider_name: typing.Optional[str] = None,
181181
scope: typing.Optional[str] = None,
182182
numeric_to_float: typing.Optional[bool] = False,
183+
is_serverless: typing.Optional[bool] = False,
184+
serverless_acct_id: typing.Optional[str] = None,
185+
serverless_work_group: typing.Optional[str] = None,
183186
) -> Connection:
184187
"""
185188
Establishes a :class:`Connection` to an Amazon Redshift cluster. This function validates user input, optionally authenticates using an identity provider plugin, then constructs a :class:`Connection` object.
@@ -275,6 +278,12 @@ def connect(
275278
Scope for BrowserAzureOauth2CredentialsProvider authentication.
276279
numeric_to_float: Optional[bool]
277280
Specifies if NUMERIC datatype values will be converted from ``decimal.Decimal`` to ``float``. By default NUMERIC values are received as ``decimal.Decimal``.
281+
is_serverless: Optional[bool]
282+
Specifies whether the Redshift endpoint is serverless (``True``) or provisioned (``False``). Default value False.
283+
serverless_acct_id: Optional[str]
284+
The AWS account ID of the serverless endpoint. Default value None.
285+
serverless_work_group: Optional[str]
286+
The name of the work group for the serverless endpoint. Default value None.
278287
Returns
279288
-------
280289
A Connection object associated with the specified Amazon Redshift cluster: :class:`Connection`
@@ -304,6 +313,7 @@ def connect(
304313
info.put("idp_host", idp_host)
305314
info.put("idp_response_timeout", idp_response_timeout)
306315
info.put("idp_tenant", idp_tenant)
316+
info.put("is_serverless", is_serverless)
307317
info.put("listen_port", listen_port)
308318
info.put("login_url", login_url)
309319
info.put("max_prepared_statements", max_prepared_statements)
@@ -321,6 +331,8 @@ def connect(
321331
info.put("role_session_name", role_session_name)
322332
info.put("scope", scope)
323333
info.put("secret_access_key", secret_access_key)
334+
info.put("serverless_acct_id", serverless_acct_id)
335+
info.put("serverless_work_group", serverless_work_group)
324336
info.put("session_token", session_token)
325337
info.put("source_address", source_address)
326338
info.put("ssl", ssl)

redshift_connector/iam_helper.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import logging
33
import typing
44

5+
import pkg_resources
56
from dateutil.tz import tzutc
7+
from packaging.version import Version
68

79
from redshift_connector.auth.aws_credentials_provider import AWSCredentialsProvider
810
from redshift_connector.credentials_holder import (
@@ -35,21 +37,24 @@ def set_iam_properties(info: RedshiftProperty) -> RedshiftProperty:
3537
# set properties present for both IAM, Native authentication
3638
IamHelper.set_auth_properties(info)
3739

38-
if info.is_serverless_host and info.iam:
39-
raise ProgrammingError("This feature is not yet available")
40-
# if Version(pkg_resources.get_distribution("boto3").version) <= Version("1.20.22"):
41-
# raise pkg_resources.VersionConflict(
42-
# "boto3 >= XXX required for authentication with Amazon Redshift serverless. "
43-
# "Please upgrade the installed version of boto3 to use this functionality."
44-
# )
40+
if info._is_serverless and info.iam:
41+
if Version(pkg_resources.get_distribution("boto3").version) < Version("1.24.5"):
42+
raise pkg_resources.VersionConflict(
43+
"boto3 >= 1.24.5 required for authentication with Amazon Redshift serverless. "
44+
"Please upgrade the installed version of boto3 to use this functionality."
45+
)
4546

4647
if info.is_serverless_host:
47-
info.set_account_id_from_host()
48-
info.set_region_from_host()
49-
info.set_work_group_from_host()
48+
# consider overridden connection parameters
49+
if not info.region:
50+
info.set_region_from_host()
51+
if not info.serverless_acct_id:
52+
info.set_serverless_acct_id()
53+
if not info.serverless_work_group:
54+
info.set_serverless_work_group_from_host()
5055

5156
if info.iam is True:
52-
if info.cluster_identifier is None and not info.is_serverless_host:
57+
if info.cluster_identifier is None and not info._is_serverless:
5358
raise InterfaceError(
5459
"Invalid connection property setting. cluster_identifier must be provided when IAM is enabled"
5560
)
@@ -136,8 +141,10 @@ def get_credentials_cache_key(info: RedshiftProperty, cred_provider: typing.Unio
136141
typing.cast(str, info.db_user if info.db_user else info.user_name),
137142
info.db_name,
138143
db_groups,
139-
typing.cast(str, info.account_id if info.is_serverless_host else info.cluster_identifier),
140-
typing.cast(str, info.work_group if info.is_serverless_host and info.work_group else ""),
144+
typing.cast(str, info.serverless_acct_id if info._is_serverless else info.cluster_identifier),
145+
typing.cast(
146+
str, info.serverless_work_group if info._is_serverless and info.serverless_work_group else ""
147+
),
141148
str(info.auto_create),
142149
str(info.duration),
143150
# v2 api parameters
@@ -171,7 +178,7 @@ def set_cluster_credentials(
171178
] = cred_provider.get_credentials() # type: ignore
172179
session_credentials: typing.Dict[str, str] = credentials_holder.get_session_credentials()
173180

174-
redshift_client: str = "redshift-serverless" if info.is_serverless_host else "redshift"
181+
redshift_client: str = "redshift-serverless" if info._is_serverless else "redshift"
175182
_logger.debug("boto3.client(service_name={}) being used for IAM auth".format(redshift_client))
176183

177184
for opt_key, opt_val in (("region_name", info.region), ("endpoint_url", info.endpoint_url)):
@@ -190,7 +197,7 @@ def set_cluster_credentials(
190197
if info.host is None or info.host == "" or info.port is None or info.port == "":
191198
response: dict
192199

193-
if info.is_serverless_host:
200+
if info._is_serverless:
194201
response = client.describe_configuration()
195202
info.put("host", response["endpoint"]["address"])
196203
info.put("port", response["endpoint"]["port"])
@@ -218,10 +225,10 @@ def set_cluster_credentials(
218225
# retries will occur by default ref:
219226
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#legacy-retry-mode
220227
_logger.debug("Credentials expired or not found...requesting from boto")
221-
if info.is_serverless_host:
228+
if info._is_serverless:
222229
get_cred_args: typing.Dict[str, str] = {"dbName": info.db_name}
223-
if info.work_group:
224-
get_cred_args["workgroupName"] = info.work_group
230+
if info.serverless_work_group:
231+
get_cred_args["workgroupName"] = info.serverless_work_group
225232

226233
cred = typing.cast(
227234
typing.Dict[str, typing.Union[str, datetime.datetime]],
@@ -247,7 +254,7 @@ def set_cluster_credentials(
247254
typing.Dict[str, typing.Union[str, datetime.datetime]], cred
248255
)
249256
# redshift-serverless api json response payload slightly differs
250-
if info.is_serverless_host:
257+
if info._is_serverless:
251258
info.put("user_name", typing.cast(str, cred["dbUser"]))
252259
info.put("password", typing.cast(str, cred["dbPassword"]))
253260
else:

redshift_connector/idp_auth_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def set_auth_properties(info: RedshiftProperty):
6868
_logger.debug("boto3 version: {}".format(Version(pkg_resources.get_distribution("boto3").version)))
6969
_logger.debug("botocore version: {}".format(Version(pkg_resources.get_distribution("botocore").version)))
7070

71-
if info.cluster_identifier is None and not info.is_serverless_host:
71+
if info.cluster_identifier is None and not info._is_serverless:
7272
raise InterfaceError(
7373
"Invalid connection property setting. cluster_identifier must be provided when IAM is enabled"
7474
)

redshift_connector/redshift_property.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,13 @@ def __init__(self: "RedshiftProperty", **kwargs):
110110
# The user name.
111111
self.user_name: str = ""
112112
self.web_identity_token: typing.Optional[str] = None
113-
# The AWS Account Id
114-
self.account_id: typing.Optional[str] = None
115113
# The name of the Redshift Native Auth Provider
116114
self.provider_name: typing.Optional[str] = None
117115
self.scope: str = ""
118116
self.numeric_to_float: bool = False
119-
# The work group used for Amazon serverless
120-
self.work_group: typing.Optional[str] = None
117+
self.is_serverless: bool = False
118+
self.serverless_acct_id: typing.Optional[str] = None
119+
self.serverless_work_group: typing.Optional[str] = None
121120

122121
else:
123122
for k, v in kwargs.items():
@@ -126,6 +125,7 @@ def __init__(self: "RedshiftProperty", **kwargs):
126125
def __str__(self: "RedshiftProperty") -> str:
127126
rp = self.__dict__
128127
rp["is_serverless_host"] = self.is_serverless_host
128+
rp["_is_serverless"] = self._is_serverless
129129
return str(rp)
130130

131131
def put_all(self, other):
@@ -158,7 +158,14 @@ def is_serverless_host(self: "RedshiftProperty") -> bool:
158158
re.fullmatch(pattern=SERVERLESS_WITH_WORKGROUP_HOST_PATTERN, string=str(self.host))
159159
)
160160

161-
def set_account_id_from_host(self: "RedshiftProperty") -> None:
161+
@property
162+
def _is_serverless(self):
163+
"""
164+
Returns True if the host matches a serverless pattern or if the is_serverless flag was set by the user
165+
"""
166+
return self.is_serverless_host or self.is_serverless
167+
168+
def set_serverless_acct_id(self: "RedshiftProperty") -> None:
162169
"""
163170
Sets the AWS account id as parsed from the Redshift serverless endpoint.
164171
"""
@@ -168,7 +175,7 @@ def set_account_id_from_host(self: "RedshiftProperty") -> None:
168175
m2 = re.fullmatch(pattern=serverless_pattern, string=self.host)
169176

170177
if m2:
171-
self.put(key="account_id", value=m2.group(typing.cast(int, m2.lastindex) - 1))
178+
self.put(key="serverless_acct_id", value=m2.group(typing.cast(int, m2.lastindex) - 1))
172179
break
173180

174181
def set_region_from_host(self: "RedshiftProperty") -> None:
@@ -184,7 +191,7 @@ def set_region_from_host(self: "RedshiftProperty") -> None:
184191
self.put(key="region", value=m2.group(typing.cast(int, m2.lastindex)))
185192
break
186193

187-
def set_work_group_from_host(self: "RedshiftProperty") -> None:
194+
def set_serverless_work_group_from_host(self: "RedshiftProperty") -> None:
188195
"""
189196
Sets the serverless_work_group as parsed from the Redshift serverless endpoint.
190197
"""
@@ -193,4 +200,4 @@ def set_work_group_from_host(self: "RedshiftProperty") -> None:
193200
m2 = re.fullmatch(pattern=SERVERLESS_WITH_WORKGROUP_HOST_PATTERN, string=self.host)
194201

195202
if m2:
196-
self.put(key="work_group", value=m2.group(1))
203+
self.put(key="serverless_work_group", value=m2.group(1))

test/unit/test_iam_helper.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,9 @@ def test_set_iam_credentials_for_serverless_calls_get_credentials(
364364
for k, v in serverless_iam_db_kwargs.items():
365365
rp.put(k, v)
366366

367-
rp.set_account_id_from_host()
367+
rp.set_serverless_acct_id()
368368
rp.set_region_from_host()
369-
rp.set_work_group_from_host()
369+
rp.set_serverless_work_group_from_host()
370370

371371
mock_cred_provider = MagicMock()
372372
mock_cred_holder = MagicMock()
@@ -378,8 +378,10 @@ def test_set_iam_credentials_for_serverless_calls_get_credentials(
378378
IamHelper.set_cluster_credentials(mock_cred_provider, rp)
379379

380380
# ensure get_credentials is called with the expected arguments
381-
if rp.work_group:
382-
mock_boto_client.assert_has_calls([call().get_credentials(dbName=rp.db_name, workgroupName=rp.work_group)])
381+
if rp.serverless_work_group:
382+
mock_boto_client.assert_has_calls(
383+
[call().get_credentials(dbName=rp.db_name, workgroupName=rp.serverless_work_group)]
384+
)
383385
else:
384386
mock_boto_client.assert_has_calls([call().get_credentials(dbName=rp.db_name)])
385387

@@ -391,6 +393,48 @@ def test_set_iam_credentials_for_serverless_calls_get_credentials(
391393
assert "password" in [c[0][0] for c in spy.call_args_list]
392394

393395

396+
def test_serverless_properties_used_when_is_serverless_true(mocker):
397+
rp: RedshiftProperty = RedshiftProperty()
398+
rp.host = "test-endpoint-xxxx.123456789123.us-east-2.redshift-serverless.amazonaws.com"
399+
rp.is_serverless = True
400+
rp.serverless_work_group = "something"
401+
rp.serverless_acct_id = "111111111111"
402+
403+
mocker.patch(
404+
"redshift_connector.native_plugin_helper.NativeAuthPluginHelper.get_native_auth_plugin_credentials",
405+
return_value=None,
406+
)
407+
408+
result = IamHelper.set_iam_properties(rp)
409+
410+
assert result.is_serverless == True
411+
assert result.serverless_work_group == rp.serverless_work_group
412+
assert result.serverless_acct_id == rp.serverless_acct_id
413+
414+
415+
def test_internal_is_serverless_prop_true_for_nlb_host():
416+
rp: RedshiftProperty = RedshiftProperty()
417+
rp.is_serverless = True
418+
419+
assert rp._is_serverless is True
420+
421+
422+
@pytest.mark.parametrize(
423+
"serverless_host",
424+
(
425+
"testwg1.012345678901.us-east-2.redshift-serverless.amazonaws.com",
426+
"012345678901.us-east-2.redshift-serverless.amazonaws.com",
427+
),
428+
)
429+
def test_internal_is_serverless_prop_true_for_serverless_host(serverless_host):
430+
rp: RedshiftProperty = RedshiftProperty()
431+
rp.host = serverless_host
432+
rp.serverless_work_group = "something"
433+
rp.serverless_acct_id = "111111111111"
434+
435+
assert rp._is_serverless is True
436+
437+
394438
def test_dynamically_loading_credential_holder(mocker):
395439
external_class_name: str = "test.unit.MockCredentialsProvider"
396440
mocker.patch("{}.get_credentials".format(external_class_name))
@@ -812,7 +856,7 @@ def test_set_iam_properties_calls_set_auth_props(mocker):
812856
spy = mocker.spy(IdpAuthHelper, "set_auth_properties")
813857
mock_rp: MagicMock = MagicMock()
814858
mock_rp.credentials_provider = None
815-
mock_rp.is_serverless_host = False
859+
mock_rp._is_serverless = False
816860
IamHelper.set_iam_properties(mock_rp)
817861

818862
assert spy.called is True

test/unit/test_redshift_property.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ def test_is_serverless_host(host, exp_is_serverless):
2525
("012345678901.ap-northeast-3.redshift-serverless.amazonaws.com", "012345678901"),
2626
],
2727
)
28-
def test_set_account_id_from_host(host, exp_account_id):
28+
def test_set_serverless_acct_id_from_host(host, exp_account_id):
2929
info: RedshiftProperty = RedshiftProperty()
3030
info.host = host
31-
info.set_account_id_from_host()
32-
assert info.account_id == exp_account_id
31+
info.set_serverless_acct_id()
32+
assert info.serverless_acct_id == exp_account_id
3333

3434

3535
@pytest.mark.parametrize(
@@ -55,8 +55,8 @@ def test_set_region_from_host(host, exp_region):
5555
("testwg2.012345678901.ap-northeast-3.redshift-serverless.amazonaws.com", "testwg2"),
5656
],
5757
)
58-
def test_set_work_group_from_host(host, exp_work_group):
58+
def test_set_serverless_work_group_from_host(host, exp_work_group):
5959
info: RedshiftProperty = RedshiftProperty()
6060
info.host = host
61-
info.set_work_group_from_host()
62-
assert info.work_group == exp_work_group
61+
info.set_serverless_work_group_from_host()
62+
assert info.serverless_work_group == exp_work_group

tutorials/001 - Connecting to Amazon Redshift.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@
495495
"source": [
496496
"# Connecting to a Redshift Serverless Endpoint\n",
497497
"\n",
498-
"Authentication methods discussed below are supported for Redshift serverless endpoints.\n",
498+
"Authentication methods discussed below are supported for Redshift serverless endpoints. If connecting using a network load balancer (NLB) or Redshift-managed VPC endpoint, please set ``is_serverless=True`` and specify the serverless account id and workgroup using ``serverless_acct_id`` and ``serverless_work_group``, respectively.\n",
499499
"\n",
500500
"### Using Database credentials (Native authentication)\n",
501501
"\n",
@@ -584,4 +584,4 @@
584584
},
585585
"nbformat": 4,
586586
"nbformat_minor": 1
587-
}
587+
}

0 commit comments

Comments
 (0)