Skip to content

Commit 33a6925

Browse files
formatting (black)
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 08e4662 commit 33a6925

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+418
-1102
lines changed

src/databricks/sql/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ def filter(self, record):
3838
)
3939
else:
4040
record.args = tuple(
41-
(self.redact(arg) if isinstance(arg, str) else arg)
42-
for arg in record.args
41+
(self.redact(arg) if isinstance(arg, str) else arg) for arg in record.args
4342
)
4443

4544
return True

src/databricks/sql/auth/authenticators.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ def __init__(
6868
try:
6969
idp_endpoint = get_oauth_endpoints(hostname, auth_type == "azure-oauth")
7070
if not idp_endpoint:
71-
raise NotImplementedError(
72-
f"OAuth is not supported for host ${hostname}"
73-
)
71+
raise NotImplementedError(f"OAuth is not supported for host ${hostname}")
7472

7573
# Convert to the corresponding scopes in the corresponding IdP
7674
cloud_scopes = idp_endpoint.get_scopes_mapping(scopes)
@@ -179,9 +177,7 @@ class AzureServicePrincipalCredentialProvider(CredentialsProvider):
179177
AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/"
180178

181179
DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token"
182-
DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = (
183-
"X-Databricks-Azure-Workspace-Resource-Id"
184-
)
180+
DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = "X-Databricks-Azure-Workspace-Resource-Id"
185181

186182
def __init__(
187183
self,
@@ -195,9 +191,7 @@ def __init__(
195191
self.azure_client_id = azure_client_id
196192
self.azure_client_secret = azure_client_secret
197193
self.azure_workspace_resource_id = azure_workspace_resource_id
198-
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
199-
hostname
200-
)
194+
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(hostname)
201195

202196
def auth_type(self) -> str:
203197
return AuthType.AZURE_SP_M2M.value
@@ -211,9 +205,7 @@ def get_token_source(self, resource: str) -> RefreshableTokenSource:
211205
)
212206

213207
def __call__(self, *args, **kwargs) -> HeaderFactory:
214-
inner = self.get_token_source(
215-
resource=get_effective_azure_login_app_id(self.hostname)
216-
)
208+
inner = self.get_token_source(resource=get_effective_azure_login_app_id(self.hostname))
217209
cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)
218210

219211
def header_factory() -> Dict[str, str]:

src/databricks/sql/auth/endpoint.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,7 @@ def infer_cloud_from_host(hostname: str) -> Optional[CloudType]:
5858

5959
def is_supported_databricks_oauth_host(hostname: str) -> bool:
6060
host = hostname.lower().replace("https://", "").split("/")[0]
61-
domains = (
62-
DATABRICKS_AWS_DOMAINS + DATABRICKS_GCP_DOMAINS + DATABRICKS_OAUTH_AZURE_DOMAINS
63-
)
61+
domains = DATABRICKS_AWS_DOMAINS + DATABRICKS_GCP_DOMAINS + DATABRICKS_OAUTH_AZURE_DOMAINS
6462
return any(e for e in domains if host.endswith(e))
6563

6664

@@ -106,7 +104,9 @@ def get_authorization_url(self, hostname: str):
106104
return f"{get_databricks_oidc_url(hostname)}/oauth2/v2.0/authorize"
107105

108106
def get_openid_config_url(self, hostname: str):
109-
return "https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration"
107+
return (
108+
"https://login.microsoftonline.com/organizations/v2.0/.well-known/openid-configuration"
109+
)
110110

111111

112112
class InHouseOAuthEndpointCollection(OAuthEndpointCollection):
@@ -123,9 +123,7 @@ def get_openid_config_url(self, hostname: str):
123123
return f"{idp_url}/.well-known/oauth-authorization-server"
124124

125125

126-
def get_oauth_endpoints(
127-
hostname: str, use_azure_auth: bool
128-
) -> Optional[OAuthEndpointCollection]:
126+
def get_oauth_endpoints(hostname: str, use_azure_auth: bool) -> Optional[OAuthEndpointCollection]:
129127
cloud = infer_cloud_from_host(hostname)
130128

131129
if cloud in [CloudType.AWS, CloudType.GCP]:

src/databricks/sql/auth/oauth.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ def __init__(self, access_token: str, token_type: str, refresh_token: str):
4141

4242
def is_expired(self) -> bool:
4343
try:
44-
decoded_token = jwt.decode(
45-
self.access_token, options={"verify_signature": False}
46-
)
44+
decoded_token = jwt.decode(self.access_token, options={"verify_signature": False})
4745
exp_time = decoded_token.get("exp")
4846
current_time = time.time()
4947
buffer_time = 30 # 30 seconds buffer
@@ -134,9 +132,7 @@ def __fetch_well_known_config(self, hostname: str):
134132
def __get_challenge():
135133
verifier_string = OAuthManager.__token_urlsafe(32)
136134
digest = hashlib.sha256(verifier_string.encode("UTF-8")).digest()
137-
challenge_string = (
138-
base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "")
139-
)
135+
challenge_string = base64.urlsafe_b64encode(digest).decode("UTF-8").replace("=", "")
140136
return verifier_string, challenge_string
141137

142138
def __get_authorization_code(self, client, auth_url, scope, state, challenge):
@@ -158,9 +154,7 @@ def __get_authorization_code(self, client, auth_url, scope, state, challenge):
158154
logger.info(f"Opening {auth_req_uri}")
159155

160156
webbrowser.open_new(auth_req_uri)
161-
logger.info(
162-
f"Listening for OAuth authorization callback at {redirect_url}"
163-
)
157+
logger.info(f"Listening for OAuth authorization callback at {redirect_url}")
164158
httpd.handle_request()
165159
self.redirect_port = port
166160
break
@@ -182,9 +176,7 @@ def __get_authorization_code(self, client, auth_url, scope, state, challenge):
182176
raise RuntimeError(msg)
183177
# This is a kludge because the parsing library expects https callbacks
184178
# We should probably set it up using https
185-
full_redirect_url = (
186-
f"https://localhost:{self.redirect_port}/{handler.request_path}"
187-
)
179+
full_redirect_url = f"https://localhost:{self.redirect_port}/{handler.request_path}"
188180
try:
189181
authorization_code_response = client.parse_request_uri_response(
190182
full_redirect_url, state=state
@@ -197,9 +189,7 @@ def __get_authorization_code(self, client, auth_url, scope, state, challenge):
197189
def __send_auth_code_token_request(
198190
self, client, token_request_url, redirect_url, code, verifier
199191
):
200-
token_request_body = client.prepare_request_body(
201-
code=code, redirect_uri=redirect_url
202-
)
192+
token_request_body = client.prepare_request_body(code=code, redirect_uri=redirect_url)
203193
data = f"{token_request_body}&code_verifier={verifier}"
204194
return self.__send_token_request(token_request_url, data)
205195

@@ -227,15 +217,11 @@ def __send_refresh_token_request(self, hostname, refresh_token):
227217
def __get_tokens_from_response(oauth_response):
228218
access_token = oauth_response["access_token"]
229219
refresh_token = (
230-
oauth_response["refresh_token"]
231-
if "refresh_token" in oauth_response
232-
else None
220+
oauth_response["refresh_token"] if "refresh_token" in oauth_response else None
233221
)
234222
return access_token, refresh_token
235223

236-
def check_and_refresh_access_token(
237-
self, hostname: str, access_token: str, refresh_token: str
238-
):
224+
def check_and_refresh_access_token(self, hostname: str, access_token: str, refresh_token: str):
239225
now = datetime.now(tz=timezone.utc)
240226
# If we can't decode an expiration time, this will be expired by default.
241227
expiration_time = now
@@ -246,9 +232,7 @@ def check_and_refresh_access_token(
246232
# an unnecessary signature verification.
247233
access_token_payload = access_token.split(".")[1]
248234
# add padding
249-
access_token_payload = access_token_payload + "=" * (
250-
-len(access_token_payload) % 4
251-
)
235+
access_token_payload = access_token_payload + "=" * (-len(access_token_payload) % 4)
252236
decoded = json.loads(base64.standard_b64decode(access_token_payload))
253237
expiration_time = datetime.fromtimestamp(decoded["exp"], tz=timezone.utc)
254238
except Exception as e:
@@ -265,13 +249,9 @@ def check_and_refresh_access_token(
265249
raise RuntimeError(msg)
266250

267251
# Try to refresh using the refresh token
268-
logger.debug(
269-
f"Attempting to refresh OAuth access token that expired on {expiration_time}"
270-
)
252+
logger.debug(f"Attempting to refresh OAuth access token that expired on {expiration_time}")
271253
oauth_response = self.__send_refresh_token_request(hostname, refresh_token)
272-
fresh_access_token, fresh_refresh_token = self.__get_tokens_from_response(
273-
oauth_response
274-
)
254+
fresh_access_token, fresh_refresh_token = self.__get_tokens_from_response(oauth_response)
275255
return fresh_access_token, fresh_refresh_token, True
276256

277257
def get_tokens(self, hostname: str, scope=None):
@@ -285,9 +265,7 @@ def get_tokens(self, hostname: str, scope=None):
285265
client = oauthlib.oauth2.WebApplicationClient(self.client_id)
286266

287267
try:
288-
auth_response = self.__get_authorization_code(
289-
client, auth_url, scope, state, challenge
290-
)
268+
auth_response = self.__get_authorization_code(client, auth_url, scope, state, challenge)
291269
except OAuth2Error as e:
292270
msg = f"OAuth Authorization Error: {e.description}"
293271
logger.error(msg)
@@ -359,6 +337,4 @@ def refresh(self) -> Token:
359337
oauth_response.refresh_token,
360338
)
361339
else:
362-
raise Exception(
363-
f"Failed to get token: {response.status_code} {response.text}"
364-
)
340+
raise Exception(f"Failed to get token: {response.status_code} {response.text}")

src/databricks/sql/auth/retry.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,7 @@ def __private_init__(
167167
new_object.command_type = command_type
168168
return new_object
169169

170-
def new(
171-
self, **urllib3_incremented_counters: typing.Any
172-
) -> "DatabricksRetryPolicy":
170+
def new(self, **urllib3_incremented_counters: typing.Any) -> "DatabricksRetryPolicy":
173171
"""This method is responsible for passing the entire Retry state to its next iteration.
174172
175173
urllib3 calls Retry.new() between successive requests as part of its `.increment()` method
@@ -435,9 +433,7 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
435433
"Failed requests are retried by default per configured DatabricksRetryPolicy",
436434
)
437435

438-
def is_retry(
439-
self, method: str, status_code: int, has_retry_after: bool = False
440-
) -> bool:
436+
def is_retry(self, method: str, status_code: int, has_retry_after: bool = False) -> bool:
441437
"""
442438
Called by urllib3 when determining whether or not to retry
443439

src/databricks/sql/auth/thrift_http_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,7 @@ def flush(self):
199199
self.headers = self.__resp.headers
200200

201201
logger.info(
202-
"HTTP Response with status code {}, message: {}".format(
203-
self.code, self.message
204-
)
202+
"HTTP Response with status code {}, message: {}".format(self.code, self.message)
205203
)
206204

207205
@staticmethod

0 commit comments

Comments
 (0)