Skip to content

Commit e4c05d1

Browse files
committed
removed unused params from driver connection params, initialize_telemetry_client does not return a telemetry client
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 4a2386b commit e4c05d1

File tree

4 files changed

+30
-31
lines changed

4 files changed

+30
-31
lines changed

src/databricks/sql/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,13 +303,17 @@ def read(self) -> Optional[OAuthToken]:
303303
kwargs.get("use_inline_params", False)
304304
)
305305

306-
self.telemetry_client = TelemetryClientFactory.initialize_telemetry_client(
306+
TelemetryClientFactory.initialize_telemetry_client(
307307
telemetry_enabled=self.telemetry_enabled,
308308
connection_uuid=self.get_session_id_hex(),
309309
auth_provider=auth_provider,
310310
host_url=self.host,
311311
)
312312

313+
self.telemetry_client = TelemetryClientFactory.get_telemetry_client(
314+
connection_uuid=self.get_session_id_hex()
315+
)
316+
313317
driver_connection_params = DriverConnectionParameters(
314318
http_path=http_path,
315319
mode=DatabricksClientType.THRIFT,

src/databricks/sql/telemetry/models/event.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@ class DriverConnectionParameters:
5454
host_info (HostDetails): Details about the host connection
5555
auth_mech (AuthMech): The authentication mechanism used
5656
auth_flow (AuthFlow): The authentication flow type
57-
auth_scope (str): The scope of authentication
58-
discovery_url (str): URL for service discovery
59-
allowed_volume_ingestion_paths (str): JSON string of allowed paths for volume operations
60-
azure_tenant_id (str): Azure tenant ID for Azure authentication
6157
socket_timeout (int): Connection timeout in milliseconds
6258
"""
6359

@@ -66,10 +62,6 @@ class DriverConnectionParameters:
6662
host_info: HostDetails
6763
auth_mech: Optional[AuthMech] = None
6864
auth_flow: Optional[AuthFlow] = None
69-
auth_scope: Optional[str] = None
70-
discovery_url: Optional[str] = None
71-
allowed_volume_ingestion_paths: Optional[str] = None
72-
azure_tenant_id: Optional[str] = None
7365
socket_timeout: Optional[int] = None
7466

7567
def to_json(self):

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ def get_auth_mechanism(auth_provider):
6060
if isinstance(auth_provider, AccessTokenAuthProvider):
6161
return AuthMech.PAT # Personal Access Token authentication
6262
elif isinstance(auth_provider, DatabricksOAuthProvider):
63-
return AuthMech.OAUTH # Databricks-managed OAuth flow
63+
return AuthMech.DATABRICKS_OAUTH # Databricks-managed OAuth flow
6464
elif isinstance(auth_provider, ExternalAuthProvider):
6565
return (
66-
AuthMech.EXTERNAL
67-
) # External identity provider (AWS IAM, Azure AD, etc.)
66+
AuthMech.EXTERNAL_AUTH
67+
) # External identity provider (AWS, Azure, etc.)
6868
return AuthMech.OTHER # Custom or unknown authentication provider
6969

7070
@staticmethod
@@ -269,29 +269,33 @@ def initialize_telemetry_client(
269269
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
270270
TelemetryClientFactory._initialize()
271271

272-
if telemetry_enabled:
273-
with TelemetryClientFactory._lock:
274-
if connection_uuid not in TelemetryClientFactory._clients:
275-
logger.info(
276-
f"Creating new TelemetryClient for connection {connection_uuid}"
277-
)
272+
with TelemetryClientFactory._lock:
273+
if connection_uuid not in TelemetryClientFactory._clients:
274+
logger.info(
275+
f"Creating new TelemetryClient for connection {connection_uuid}"
276+
)
277+
if telemetry_enabled:
278278
TelemetryClientFactory._clients[connection_uuid] = TelemetryClient(
279279
telemetry_enabled=telemetry_enabled,
280280
connection_uuid=connection_uuid,
281281
auth_provider=auth_provider,
282282
host_url=host_url,
283283
executor=TelemetryClientFactory._executor,
284284
)
285-
return TelemetryClientFactory._clients[connection_uuid]
286-
else:
287-
return NoopTelemetryClient()
285+
else:
286+
TelemetryClientFactory._clients[
287+
connection_uuid
288+
] = NoopTelemetryClient()
288289

289290
@staticmethod
290291
def get_telemetry_client(connection_uuid):
291292
"""Get the telemetry client for a specific connection"""
292293
if connection_uuid in TelemetryClientFactory._clients:
293294
return TelemetryClientFactory._clients[connection_uuid]
294295
else:
296+
logger.error(
297+
f"Telemetry client not initialized for connection {connection_uuid}"
298+
)
295299
return NoopTelemetryClient()
296300

297301
@staticmethod

tests/unit/test_telemetry.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def test_initialize_telemetry_client_enabled(self, mock_client_class):
182182
mock_client = MagicMock()
183183
mock_client_class.return_value = mock_client
184184

185-
client = TelemetryClientFactory.initialize_telemetry_client(
185+
TelemetryClientFactory.initialize_telemetry_client(
186186
telemetry_enabled=True,
187187
connection_uuid=connection_uuid,
188188
auth_provider=auth_provider,
@@ -197,7 +197,6 @@ def test_initialize_telemetry_client_enabled(self, mock_client_class):
197197
host_url=host_url,
198198
executor=TelemetryClientFactory._executor,
199199
)
200-
self.assertEqual(client, mock_client)
201200
self.assertEqual(TelemetryClientFactory._clients[connection_uuid], mock_client)
202201

203202
# Call again with the same connection_uuid
@@ -209,18 +208,18 @@ def test_initialize_telemetry_client_enabled(self, mock_client_class):
209208

210209
def test_initialize_telemetry_client_disabled(self):
211210
"""Test initializing a telemetry client when telemetry is disabled."""
212-
client = TelemetryClientFactory.initialize_telemetry_client(
211+
connection_uuid = "test-uuid"
212+
TelemetryClientFactory.initialize_telemetry_client(
213213
telemetry_enabled=False,
214-
connection_uuid="test-uuid",
214+
connection_uuid=connection_uuid,
215215
auth_provider=MagicMock(),
216216
host_url="test-host",
217217
)
218218

219-
# Verify a NoopTelemetryClient was returned
220-
self.assertIsInstance(client, NoopTelemetryClient)
221-
self.assertEqual(TelemetryClientFactory._clients, {}) # No client was stored
219+
# Verify a NoopTelemetryClient was stored
220+
self.assertIsInstance(TelemetryClientFactory._clients[connection_uuid], NoopTelemetryClient)
222221

223-
client2 = TelemetryClientFactory.get_telemetry_client("test-uuid")
222+
client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid)
224223
self.assertIsInstance(client2, NoopTelemetryClient)
225224

226225
def test_get_telemetry_client_existing(self):
@@ -267,7 +266,7 @@ def test_close(self, mock_client_class, mock_executor_class):
267266
mock_executor_class.return_value = mock_executor2
268267
mock_client_class.return_value = mock_client2
269268

270-
client = TelemetryClientFactory.initialize_telemetry_client(
269+
TelemetryClientFactory.initialize_telemetry_client(
271270
telemetry_enabled=True,
272271
connection_uuid=connection_uuid2,
273272
auth_provider=MagicMock(),
@@ -279,7 +278,7 @@ def test_close(self, mock_client_class, mock_executor_class):
279278
self.assertIsNotNone(TelemetryClientFactory._executor)
280279
self.assertEqual(TelemetryClientFactory._executor, mock_executor2)
281280
self.assertIn(connection_uuid2, TelemetryClientFactory._clients)
282-
self.assertEqual(client, mock_client2)
281+
self.assertEqual(TelemetryClientFactory._clients[connection_uuid2], mock_client2)
283282

284283
# Verify new ThreadPoolExecutor was created
285284
self.assertEqual(mock_executor_class.call_count, 1)

0 commit comments

Comments
 (0)