Skip to content

Commit 7c33fe4

Browse files
more fixes
Signed-off-by: Vikrant Puppala <vikrant.puppala@databricks.com>
1 parent d1f045e commit 7c33fe4

File tree

9 files changed

+100
-83
lines changed

9 files changed

+100
-83
lines changed

src/databricks/sql/auth/authenticators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ def __init__(
199199
self.azure_client_secret = azure_client_secret
200200
self.azure_workspace_resource_id = azure_workspace_resource_id
201201
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
202-
hostname
202+
hostname,
203+
http_client
203204
)
204205
self._http_client = http_client
205206

src/databricks/sql/auth/common.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,15 @@ def get_azure_tenant_id_from_host(host: str, http_client) -> str:
116116
logger.debug("Loading tenant ID from %s", login_url)
117117

118118
with http_client.request_context(
119-
HttpMethod.GET, login_url, allow_redirects=False
119+
HttpMethod.GET, login_url
120120
) as resp:
121-
if resp.status // 100 != 3:
122-
raise ValueError(
123-
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
124-
)
125-
entra_id_endpoint = dict(resp.headers).get("Location")
121+
# if resp.status // 100 != 3:
122+
# raise ValueError(
123+
# f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status}"
124+
# )
125+
entra_id_endpoint = resp.retries.history[-1].redirect_location
126126
if entra_id_endpoint is None:
127-
raise ValueError(f"No Location header in response from {login_url}")
127+
raise ValueError(f"No Location header in response from {login_url}: {entra_id_endpoint}")
128128

129129
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
130130
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).

src/databricks/sql/auth/oauth.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -336,18 +336,23 @@ def refresh(self) -> Token:
336336
**self.extra_params,
337337
}
338338
)
339+
339340

340-
with self._http_client.execute(
341-
method=HttpMethod.POST, url=self.token_url, headers=headers, data=data
342-
) as response:
343-
if response.status_code == 200:
344-
oauth_response = OAuthResponse(**response.json())
341+
response = self._http_client.request(
342+
method=HttpMethod.POST, url=self.token_url, headers=headers, body=data
343+
)
344+
try:
345+
if response.status == 200:
346+
import json
347+
oauth_response = OAuthResponse(**json.loads(response.data.decode('utf-8')))
345348
return Token(
346349
oauth_response.access_token,
347350
oauth_response.token_type,
348351
oauth_response.refresh_token,
349352
)
350353
else:
351354
raise Exception(
352-
f"Failed to get token: {response.status_code} {response.text}"
355+
f"Failed to get token: {response.status} {response.data.decode('utf-8')}"
353356
)
357+
finally:
358+
response.close()

src/databricks/sql/auth/retry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]:
355355
logger.info(f"Received status code {status_code} for {method} request")
356356

357357
# Request succeeded. Don't retry.
358-
if status_code // 100 == 2:
358+
if status_code // 100 <= 3:
359359
return False, "2xx codes are not retried"
360360

361361
if status_code == 401:

src/databricks/sql/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def read(self) -> Optional[OAuthToken]:
279279
host_url=server_hostname,
280280
http_path=http_path,
281281
port=kwargs.get("_port", 443),
282-
http_client=self.http_client,
282+
client_context=client_context,
283283
user_agent=self.session.useragent_header
284284
if hasattr(self, "session")
285285
else None,
@@ -301,7 +301,7 @@ def read(self) -> Optional[OAuthToken]:
301301
auth_provider=self.session.auth_provider,
302302
host_url=self.session.host,
303303
batch_size=self.telemetry_batch_size,
304-
http_client=self.http_client,
304+
client_context=client_context,
305305
)
306306

307307
self._telemetry_client = TelemetryClientFactory.get_telemetry_client(

src/databricks/sql/common/unified_http_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def request_context(
154154
Yields:
155155
urllib3.HTTPResponse: The HTTP response object
156156
"""
157-
logger.debug("Making %s request to %s", method, url)
157+
logger.debug("Making %s request to %s", method, urllib.parse.urlparse(url).netloc)
158158

159159
request_headers = self._prepare_headers(headers)
160160

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def __init__(
172172
host_url,
173173
executor,
174174
batch_size,
175-
http_client,
175+
client_context,
176176
):
177177
logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
178178
self._telemetry_enabled = telemetry_enabled
@@ -186,8 +186,10 @@ def __init__(
186186
self._host_url = host_url
187187
self._executor = executor
188188

189-
# Use the provided HTTP client directly
190-
self._http_client = http_client
189+
# Create own HTTP client from client context
190+
from databricks.sql.common.unified_http_client import UnifiedHttpClient
191+
192+
self._http_client = UnifiedHttpClient(client_context)
191193

192194
def _export_event(self, event):
193195
"""Add an event to the batch queue and flush if batch is full"""
@@ -456,7 +458,7 @@ def initialize_telemetry_client(
456458
auth_provider,
457459
host_url,
458460
batch_size,
459-
http_client,
461+
client_context,
460462
):
461463
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
462464
try:
@@ -479,7 +481,7 @@ def initialize_telemetry_client(
479481
host_url=host_url,
480482
executor=TelemetryClientFactory._executor,
481483
batch_size=batch_size,
482-
http_client=http_client,
484+
client_context=client_context,
483485
)
484486
else:
485487
TelemetryClientFactory._clients[
@@ -532,10 +534,10 @@ def connection_failure_log(
532534
host_url: str,
533535
http_path: str,
534536
port: int,
535-
http_client,
537+
client_context,
536538
user_agent: Optional[str] = None,
537539
):
538-
"""Send error telemetry when connection creation fails, using existing HTTP client"""
540+
"""Send error telemetry when connection creation fails, using provided client context"""
539541

540542
UNAUTH_DUMMY_SESSION_ID = "unauth_session_id"
541543

@@ -545,7 +547,7 @@ def connection_failure_log(
545547
auth_provider=None,
546548
host_url=host_url,
547549
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
548-
http_client=http_client,
550+
client_context=client_context,
549551
)
550552

551553
telemetry_client = TelemetryClientFactory.get_telemetry_client(

tests/e2e/test_concurrent_telemetry.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
from unittest.mock import patch
66
import pytest
7+
import json
78

89
from databricks.sql.telemetry.models.enums import StatementType
910
from databricks.sql.telemetry.telemetry_client import (
@@ -119,8 +120,12 @@ def execute_query_worker(thread_id):
119120
for future in done:
120121
try:
121122
response = future.result()
122-
response.raise_for_status()
123-
captured_responses.append(response.json())
123+
# Check status using urllib3 method (response.status instead of response.raise_for_status())
124+
if response.status >= 400:
125+
raise Exception(f"HTTP {response.status}: {getattr(response, 'reason', 'Unknown')}")
126+
# Parse JSON using urllib3 method (response.data.decode() instead of response.json())
127+
response_data = json.loads(response.data.decode()) if response.data else {}
128+
captured_responses.append(response_data)
124129
except Exception as e:
125130
captured_exceptions.append(e)
126131

tests/unit/test_telemetry.py

Lines changed: 59 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,19 @@ def mock_telemetry_client():
2424
session_id = str(uuid.uuid4())
2525
auth_provider = AccessTokenAuthProvider("test-token")
2626
executor = MagicMock()
27-
mock_http_client = MagicMock()
28-
29-
return TelemetryClient(
30-
telemetry_enabled=True,
31-
session_id_hex=session_id,
32-
auth_provider=auth_provider,
33-
host_url="test-host.com",
34-
executor=executor,
35-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
36-
http_client=mock_http_client,
37-
)
27+
client_context = MagicMock()
28+
29+
# Patch the _setup_pool_manager method to avoid SSL file loading
30+
with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'):
31+
return TelemetryClient(
32+
telemetry_enabled=True,
33+
session_id_hex=session_id,
34+
auth_provider=auth_provider,
35+
host_url="test-host.com",
36+
executor=executor,
37+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
38+
client_context=client_context,
39+
)
3840

3941

4042
class TestNoopTelemetryClient:
@@ -216,41 +218,42 @@ def test_client_lifecycle_flow(self):
216218
"""Test complete client lifecycle: initialize -> use -> close."""
217219
session_id_hex = "test-session"
218220
auth_provider = AccessTokenAuthProvider("token")
219-
mock_http_client = MagicMock()
221+
client_context = MagicMock()
220222

221223
# Initialize enabled client
222-
TelemetryClientFactory.initialize_telemetry_client(
223-
telemetry_enabled=True,
224-
session_id_hex=session_id_hex,
225-
auth_provider=auth_provider,
226-
host_url="test-host.com",
227-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
228-
http_client=mock_http_client,
229-
)
224+
with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'):
225+
TelemetryClientFactory.initialize_telemetry_client(
226+
telemetry_enabled=True,
227+
session_id_hex=session_id_hex,
228+
auth_provider=auth_provider,
229+
host_url="test-host.com",
230+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
231+
client_context=client_context,
232+
)
230233

231-
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
232-
assert isinstance(client, TelemetryClient)
233-
assert client._session_id_hex == session_id_hex
234+
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
235+
assert isinstance(client, TelemetryClient)
236+
assert client._session_id_hex == session_id_hex
234237

235-
# Close client
236-
with patch.object(client, "close") as mock_close:
237-
TelemetryClientFactory.close(session_id_hex)
238-
mock_close.assert_called_once()
238+
# Close client
239+
with patch.object(client, "close") as mock_close:
240+
TelemetryClientFactory.close(session_id_hex)
241+
mock_close.assert_called_once()
239242

240-
# Should get NoopTelemetryClient after close
243+
# Should get NoopTelemetryClient after close
241244

242245
def test_disabled_telemetry_creates_noop_client(self):
243246
"""Test that disabled telemetry creates NoopTelemetryClient."""
244247
session_id_hex = "test-session"
245-
mock_http_client = MagicMock()
248+
client_context = MagicMock()
246249

247250
TelemetryClientFactory.initialize_telemetry_client(
248251
telemetry_enabled=False,
249252
session_id_hex=session_id_hex,
250253
auth_provider=None,
251254
host_url="test-host.com",
252255
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
253-
http_client=mock_http_client,
256+
client_context=client_context,
254257
)
255258

256259
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -259,7 +262,7 @@ def test_disabled_telemetry_creates_noop_client(self):
259262
def test_factory_error_handling(self):
260263
"""Test that factory errors fall back to NoopTelemetryClient."""
261264
session_id = "test-session"
262-
mock_http_client = MagicMock()
265+
client_context = MagicMock()
263266

264267
# Simulate initialization error
265268
with patch(
@@ -272,7 +275,7 @@ def test_factory_error_handling(self):
272275
auth_provider=AccessTokenAuthProvider("token"),
273276
host_url="test-host.com",
274277
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
275-
http_client=mock_http_client,
278+
client_context=client_context,
276279
)
277280

278281
# Should fall back to NoopTelemetryClient
@@ -283,31 +286,32 @@ def test_factory_shutdown_flow(self):
283286
"""Test factory shutdown when last client is removed."""
284287
session1 = "session-1"
285288
session2 = "session-2"
286-
mock_http_client = MagicMock()
289+
client_context = MagicMock()
287290

288291
# Initialize multiple clients
289-
for session in [session1, session2]:
290-
TelemetryClientFactory.initialize_telemetry_client(
291-
telemetry_enabled=True,
292-
session_id_hex=session,
293-
auth_provider=AccessTokenAuthProvider("token"),
294-
host_url="test-host.com",
295-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
296-
http_client=mock_http_client,
297-
)
298-
299-
# Factory should be initialized
300-
assert TelemetryClientFactory._initialized is True
301-
assert TelemetryClientFactory._executor is not None
302-
303-
# Close first client - factory should stay initialized
304-
TelemetryClientFactory.close(session1)
305-
assert TelemetryClientFactory._initialized is True
306-
307-
# Close second client - factory should shut down
308-
TelemetryClientFactory.close(session2)
309-
assert TelemetryClientFactory._initialized is False
310-
assert TelemetryClientFactory._executor is None
292+
with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_manager'):
293+
for session in [session1, session2]:
294+
TelemetryClientFactory.initialize_telemetry_client(
295+
telemetry_enabled=True,
296+
session_id_hex=session,
297+
auth_provider=AccessTokenAuthProvider("token"),
298+
host_url="test-host.com",
299+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
300+
client_context=client_context,
301+
)
302+
303+
# Factory should be initialized
304+
assert TelemetryClientFactory._initialized is True
305+
assert TelemetryClientFactory._executor is not None
306+
307+
# Close first client - factory should stay initialized
308+
TelemetryClientFactory.close(session1)
309+
assert TelemetryClientFactory._initialized is True
310+
311+
# Close second client - factory should shut down
312+
TelemetryClientFactory.close(session2)
313+
assert TelemetryClientFactory._initialized is False
314+
assert TelemetryClientFactory._executor is None
311315

312316
@patch(
313317
"databricks.sql.telemetry.telemetry_client.TelemetryClient.export_failure_log"

0 commit comments

Comments
 (0)