Skip to content

Commit 738f2ad

Browse files
committed
feat: support Redshift native authentication, Add Azure Oauth2 IdP
1 parent 28266c3 commit 738f2ad

13 files changed

+891
-372
lines changed

redshift_connector/__init__.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import typing
33

44
from redshift_connector import plugin
5-
from redshift_connector.config import DEFAULT_PROTOCOL_VERSION
5+
from redshift_connector.config import DEFAULT_PROTOCOL_VERSION, ClientProtocolVersion
66
from redshift_connector.core import BINARY, Connection, Cursor
77
from redshift_connector.error import (
88
ArrayContentNotHomogenousError,
@@ -177,6 +177,8 @@ def connect(
177177
iam_disable_cache: typing.Optional[bool] = None,
178178
auth_profile: typing.Optional[str] = None,
179179
endpoint_url: typing.Optional[str] = None,
180+
provider_name: typing.Optional[str] = None,
181+
scope: typing.Optional[str] = None,
180182
) -> Connection:
181183
"""
182184
Establishes a :class:`Connection` to an Amazon Redshift cluster. This function validates user input, optionally authenticates using an identity provider plugin, then constructs a :class:`Connection` object.
@@ -266,6 +268,10 @@ def connect(
266268
The name of an Amazon Redshift Authentication profile having connection properties as JSON. See :class:RedshiftProperty to learn how connection properties should be named.
267269
endpoint_url: Optional[str]
268270
The Amazon Redshift endpoint url. This option is only used by AWS internal teams.
271+
provider_name: Optional[str]
272+
The name of the Redshift Native Auth Provider.
273+
scope: Optional[str]
274+
Scope for BrowserAzureOauth2CredentialsProvider authentication.
269275
Returns
270276
-------
271277
A Connection object associated with the specified Amazon Redshift cluster: :class:`Connection`
@@ -304,10 +310,12 @@ def connect(
304310
info.put("preferred_role", preferred_role)
305311
info.put("principal", principal_arn)
306312
info.put("profile", profile)
313+
info.put("provider_name", provider_name)
307314
info.put("region", region)
308315
info.put("replication", replication)
309316
info.put("role_arn", role_arn)
310317
info.put("role_session_name", role_session_name)
318+
info.put("scope", scope)
311319
info.put("secret_access_key", secret_access_key)
312320
info.put("session_token", session_token)
313321
info.put("source_address", source_address)
@@ -326,7 +334,27 @@ def connect(
326334
_logger.debug(mask_secure_info_in_props(info).__str__())
327335
_logger.debug(make_divider_block())
328336

329-
IamHelper.set_iam_properties(info)
337+
if (info.ssl is False) and (info.iam is True):
338+
raise InterfaceError("Invalid connection property setting. SSL must be enabled when using IAM")
339+
340+
if (info.iam is False) and (info.ssl_insecure is False):
341+
raise InterfaceError("Invalid connection property setting. IAM must be enabled when using ssl_insecure")
342+
343+
if info.client_protocol_version not in ClientProtocolVersion.list():
344+
raise InterfaceError(
345+
"Invalid connection property setting. client_protocol_version must be in: {}".format(
346+
ClientProtocolVersion.list()
347+
)
348+
)
349+
350+
redshift_native_auth: bool = False
351+
if info.iam:
352+
if info.credentials_provider == "BasicJwtCredentialsProvider":
353+
redshift_native_auth = True
354+
_logger.debug("redshift_native_auth enabled")
355+
356+
if not redshift_native_auth:
357+
IamHelper.set_iam_properties(info)
330358

331359
_logger.debug(make_divider_block())
332360
_logger.debug("Connection arguments following validation and IAM auth (if applicable)")
@@ -352,6 +380,8 @@ def connect(
352380
client_protocol_version=info.client_protocol_version,
353381
database_metadata_current_db_only=info.database_metadata_current_db_only,
354382
credentials_provider=info.credentials_provider,
383+
provider_name=info.provider_name,
384+
web_identity_token=info.web_identity_token,
355385
)
356386

357387

redshift_connector/core.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,8 @@ def __init__(
417417
client_protocol_version: int = DEFAULT_PROTOCOL_VERSION,
418418
database_metadata_current_db_only: bool = True,
419419
credentials_provider: typing.Optional[str] = None,
420+
provider_name: typing.Optional[str] = None,
421+
web_identity_token: typing.Optional[str] = None,
420422
):
421423
"""
422424
Creates a :class:`Connection` to an Amazon Redshift cluster. For more information on establishing a connection to an Amazon Redshift cluster using `federated API access <https://aws.amazon.com/blogs/big-data/federated-api-access-to-amazon-redshift-using-an-amazon-redshift-connector-for-python/>`_ see our examples page.
@@ -455,6 +457,10 @@ def __init__(
455457
Is `datashare <https://docs.aws.amazon.com/redshift/latest/dg/datashare-overview.html>`_ disabled. Default value is True, implying datasharing will not be used.
456458
credentials_provider : Optional[str]
457459
The class-path of the IdP plugin used for authentication with Amazon Redshift.
460+
provider_name : Optional[str]
461+
The name of the Redshift Native Auth Provider.
462+
web_identity_token: Optional[str]
463+
A web identity token used for authentication via Redshift Native IDP Integration
458464
"""
459465
self.merge_socket_read = True
460466

@@ -483,11 +489,15 @@ def __init__(
483489
# for receiving some datatypes
484490
self._enable_protocol_based_conversion_funcs()
485491

492+
self.web_identity_token = web_identity_token
493+
486494
if user is None:
487495
raise InterfaceError("The 'user' connection parameter cannot be None")
488496

497+
redshift_native_auth: bool = False
498+
489499
init_params: typing.Dict[str, typing.Optional[typing.Union[str, bytes]]] = {
490-
"user": user,
500+
"user": "",
491501
"database": database,
492502
"application_name": application_name,
493503
"replication": replication,
@@ -499,6 +509,19 @@ def __init__(
499509
if credentials_provider:
500510
init_params["plugin_name"] = credentials_provider
501511

512+
if credentials_provider.split(".")[-1] in (
513+
"BasicJwtCredentialsProvider",
514+
"BrowserAzureOAuth2CredentialsProvider",
515+
):
516+
redshift_native_auth = True
517+
init_params["idp_type"] = "AzureAD"
518+
519+
if provider_name:
520+
init_params["provider_name"] = provider_name
521+
522+
if not redshift_native_auth or user:
523+
init_params["user"] = user
524+
502525
_logger.debug(make_divider_block())
503526
_logger.debug("Establishing a connection")
504527
_logger.debug(init_params)
@@ -512,7 +535,10 @@ def __init__(
512535
elif not isinstance(v, (bytes, bytearray)):
513536
raise InterfaceError("The parameter " + k + " can't be of type " + str(type(v)) + ".")
514537

515-
self.user: bytes = typing.cast(bytes, init_params["user"])
538+
if "user" in init_params:
539+
self.user: bytes = typing.cast(bytes, init_params["user"])
540+
else:
541+
self.user = b""
516542

517543
if isinstance(password, str):
518544
self.password: bytes = password.encode("utf8")
@@ -1217,6 +1243,7 @@ def handle_AUTHENTICATION_REQUEST(self: "Connection", data: bytes, cursor: Curso
12171243
7 = GSSAPI (not supported)
12181244
8 = GSSAPI data (not supported)
12191245
9 = SSPI (not supported)
1246+
14 = Redshift Native IDP Integration
12201247
12211248
Please note that some authentication messages have additional data following the authentication code.
12221249
That data is documented in the appropriate conditional branch below.
@@ -1284,6 +1311,22 @@ def handle_AUTHENTICATION_REQUEST(self: "Connection", data: bytes, cursor: Curso
12841311
elif auth_code == 12:
12851312
# AuthenticationSASLFinal
12861313
self.auth.set_server_final(data[4:].decode("utf8"))
1314+
elif auth_code == 14:
1315+
# Redshift Native IDP Integration
1316+
aad_token: str = typing.cast(str, self.web_identity_token)
1317+
_logger.debug("<=BE Authentication request IDP")
1318+
1319+
if not aad_token:
1320+
raise ConnectionAbortedError(
1321+
"The server requested AAD token-based authentication, but no token was provided."
1322+
)
1323+
1324+
_logger.debug("FE=> IDP(AAD Token)")
1325+
1326+
token: bytes = aad_token.encode(encoding="utf-8")
1327+
self._write(create_message(b"i", token))
1328+
# self._write(NULL_BYTE)
1329+
self._flush()
12871330

12881331
elif auth_code in (2, 4, 6, 7, 8, 9):
12891332
raise InterfaceError("Authentication method " + str(auth_code) + " not supported by redshift_connector.")

0 commit comments

Comments
 (0)