|
| 1 | +import typing |
| 2 | + |
| 3 | +import pytest |
| 4 | +from pytest_mock import mocker # type: ignore |
| 5 | + |
| 6 | +from redshift_connector.error import InterfaceError |
| 7 | +from redshift_connector.plugin.browser_azure_oauth2_credentials_provider import ( |
| 8 | + BrowserAzureOAuth2CredentialsProvider, |
| 9 | +) |
| 10 | +from redshift_connector.redshift_property import RedshiftProperty |
| 11 | + |
| 12 | + |
| 13 | +def make_valid_azure_oauth2_provider() -> typing.Tuple[BrowserAzureOAuth2CredentialsProvider, RedshiftProperty]: |
| 14 | + rp: RedshiftProperty = RedshiftProperty() |
| 15 | + rp.idp_tenant = "my_idp_tenant" |
| 16 | + rp.client_id = "my_client_id" |
| 17 | + rp.scope = "my_scope" |
| 18 | + rp.idp_response_timeout = 900 |
| 19 | + rp.listen_port = 1099 |
| 20 | + cp: BrowserAzureOAuth2CredentialsProvider = BrowserAzureOAuth2CredentialsProvider() |
| 21 | + cp.add_parameter(rp) |
| 22 | + return cp, rp |
| 23 | + |
| 24 | + |
| 25 | +def test_add_parameter_sets_azure_oauth2_specific(): |
| 26 | + acp, rp = make_valid_azure_oauth2_provider() |
| 27 | + assert acp.idp_tenant == rp.idp_tenant |
| 28 | + assert acp.client_id == rp.client_id |
| 29 | + assert acp.scope == rp.scope |
| 30 | + assert acp.idp_response_timeout == rp.idp_response_timeout |
| 31 | + assert acp.listen_port == rp.listen_port |
| 32 | + |
| 33 | + |
| 34 | +@pytest.mark.parametrize("value", [None, ""]) |
| 35 | +def test_check_required_parameters_raises_if_idp_tenant_missing_or_too_small(value): |
| 36 | + acp, _ = make_valid_azure_oauth2_provider() |
| 37 | + acp.idp_tenant = value |
| 38 | + |
| 39 | + with pytest.raises(InterfaceError, match="BrowserAzureOauth2CredentialsProvider requires idp_tenant"): |
| 40 | + acp.get_jwt_assertion() |
| 41 | + |
| 42 | + |
| 43 | +@pytest.mark.parametrize("value", [None, ""]) |
| 44 | +def test_check_required_parameters_raises_if_client_id_missing(value): |
| 45 | + acp, _ = make_valid_azure_oauth2_provider() |
| 46 | + acp.client_id = value |
| 47 | + |
| 48 | + with pytest.raises(InterfaceError, match="BrowserAzureOauth2CredentialsProvider requires client_id"): |
| 49 | + acp.get_jwt_assertion() |
| 50 | + |
| 51 | + |
| 52 | +@pytest.mark.parametrize("value", [None, ""]) |
| 53 | +def test_check_required_parameters_raises_if_idp_response_timeout_missing(value): |
| 54 | + acp, _ = make_valid_azure_oauth2_provider() |
| 55 | + acp.idp_response_timeout = value |
| 56 | + |
| 57 | + with pytest.raises(InterfaceError, match="BrowserAzureOauth2CredentialsProvider requires idp_response_timeout"): |
| 58 | + acp.get_jwt_assertion() |
| 59 | + |
| 60 | + |
| 61 | +def test_get_jwt_assertion_fetches_and_extracts(mocker): |
| 62 | + mock_token: str = "mock_token" |
| 63 | + mock_content: str = "mock_content" |
| 64 | + mock_jwt_assertion: str = "mock_jwt_assertion" |
| 65 | + mocker.patch( |
| 66 | + "redshift_connector.plugin.browser_azure_oauth2_credentials_provider." |
| 67 | + "BrowserAzureOAuth2CredentialsProvider.fetch_authorization_token", |
| 68 | + return_value=mock_token, |
| 69 | + ) |
| 70 | + mocker.patch( |
| 71 | + "redshift_connector.plugin.browser_azure_oauth2_credentials_provider." |
| 72 | + "BrowserAzureOAuth2CredentialsProvider.fetch_jwt_response", |
| 73 | + return_value=mock_content, |
| 74 | + ) |
| 75 | + mocker.patch( |
| 76 | + "redshift_connector.plugin.browser_azure_oauth2_credentials_provider." |
| 77 | + "BrowserAzureOAuth2CredentialsProvider.extract_jwt_assertion", |
| 78 | + return_value=mock_jwt_assertion, |
| 79 | + ) |
| 80 | + acp, rp = make_valid_azure_oauth2_provider() |
| 81 | + |
| 82 | + fetch_token_spy = mocker.spy(acp, "fetch_authorization_token") |
| 83 | + fetch_jwt_spy = mocker.spy(acp, "fetch_jwt_response") |
| 84 | + extract_jwt_spy = mocker.spy(acp, "extract_jwt_assertion") |
| 85 | + |
| 86 | + jwt_assertion: str = acp.get_jwt_assertion() |
| 87 | + |
| 88 | + assert fetch_token_spy.called is True |
| 89 | + assert fetch_token_spy.call_count == 1 |
| 90 | + |
| 91 | + assert fetch_jwt_spy.called is True |
| 92 | + assert fetch_jwt_spy.call_count == 1 |
| 93 | + assert fetch_jwt_spy.call_args[0][0] == mock_token |
| 94 | + |
| 95 | + assert extract_jwt_spy.called is True |
| 96 | + assert extract_jwt_spy.call_count == 1 |
| 97 | + assert extract_jwt_spy.call_args[0][0] == mock_content |
| 98 | + |
| 99 | + assert jwt_assertion == mock_jwt_assertion |
0 commit comments