Skip to content

Commit 906c187

Browse files
committed
changed TelemetryClientFactory to a static class and made changes in unit tests accordingly
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 5f8dff5 commit 906c187

File tree

3 files changed

+152
-116
lines changed

3 files changed

+152
-116
lines changed

src/databricks/sql/client.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
TOperationState,
5151
)
5252
from databricks.sql.telemetry.telemetry_client import (
53-
telemetry_client_factory,
53+
TelemetryClientFactory,
5454
TelemetryHelper,
5555
)
5656
from databricks.sql.telemetry.models.enums import DatabricksClientType
@@ -303,6 +303,13 @@ def read(self) -> Optional[OAuthToken]:
303303
kwargs.get("use_inline_params", False)
304304
)
305305

306+
self.telemetry_client = TelemetryClientFactory.initialize_telemetry_client(
307+
telemetry_enabled=self.telemetry_enabled,
308+
connection_uuid=self.get_session_id_hex(),
309+
auth_provider=auth_provider,
310+
host_url=self.host,
311+
)
312+
306313
driver_connection_params = DriverConnectionParameters(
307314
http_path=http_path,
308315
mode=DatabricksClientType.THRIFT,
@@ -313,16 +320,11 @@ def read(self) -> Optional[OAuthToken]:
313320
socket_timeout=kwargs.get("_socket_timeout", None),
314321
)
315322

316-
self.telemetry_client = telemetry_client_factory.initialize_telemetry_client(
317-
telemetry_enabled=self.telemetry_enabled,
318-
connection_uuid=self.get_session_id_hex(),
319-
auth_provider=auth_provider,
320-
user_agent=useragent_header,
323+
self.telemetry_client.export_initial_telemetry_log(
321324
driver_connection_params=driver_connection_params,
325+
user_agent=useragent_header,
322326
)
323327

324-
self.telemetry_client.export_initial_telemetry_log()
325-
326328
def _set_use_inline_params_with_warning(self, value: Union[bool, str]):
327329
"""Valid values are True, False, and "silent"
328330

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __new__(cls):
126126
cls._instance = super(NoopTelemetryClient, cls).__new__(cls)
127127
return cls._instance
128128

129-
def export_initial_telemetry_log(self, **kwargs):
129+
def export_initial_telemetry_log(self, driver_connection_params, user_agent):
130130
pass
131131

132132
def close(self):
@@ -144,20 +144,19 @@ def __init__(
144144
telemetry_enabled,
145145
connection_uuid,
146146
auth_provider,
147-
user_agent,
148-
driver_connection_params,
147+
host_url,
149148
executor,
150149
):
151150
logger.info(f"Initializing TelemetryClient for connection: {connection_uuid}")
152151
self.telemetry_enabled = telemetry_enabled
153152
self.batch_size = 10 # TODO: Decide on batch size
154153
self.connection_uuid = connection_uuid
155154
self.auth_provider = auth_provider
156-
self.user_agent = user_agent
155+
self.user_agent = None
157156
self.events_batch = []
158157
self.lock = threading.Lock()
159-
self.driver_connection_params = driver_connection_params
160-
self.host_url = driver_connection_params.host_info.host_url
158+
self.driver_connection_params = None
159+
self.host_url = host_url
161160
self.executor = executor
162161

163162
def export_event(self, event):
@@ -210,11 +209,14 @@ def _send_telemetry(self, events):
210209
except Exception as e:
211210
logger.error(f"Failed to submit telemetry request: {e}")
212211

213-
def export_initial_telemetry_log(self):
212+
def export_initial_telemetry_log(self, driver_connection_params, user_agent):
214213
logger.info(
215214
f"Exporting initial telemetry log for connection {self.connection_uuid}"
216215
)
217216

217+
self.driver_connection_params = driver_connection_params
218+
self.user_agent = user_agent
219+
218220
telemetry_frontend_log = TelemetryFrontendLog(
219221
frontend_log_event_id=str(uuid.uuid4()),
220222
context=FrontendLogContext(
@@ -237,73 +239,78 @@ def close(self):
237239
"""Flush remaining events before closing"""
238240
logger.info(f"Closing TelemetryClient for connection {self.connection_uuid}")
239241
self.flush()
240-
telemetry_client_factory.close(self.connection_uuid)
242+
TelemetryClientFactory.close(self.connection_uuid)
241243

242244

243245
class TelemetryClientFactory:
244246
"""
245-
Factory class for creating and managing telemetry clients.
247+
Static factory class for creating and managing telemetry clients.
246248
It uses a thread pool to handle asynchronous operations.
247249
"""
248250

249-
_instance = None
250-
251-
def __new__(cls):
252-
if cls._instance is None:
253-
cls._instance = super(TelemetryClientFactory, cls).__new__(cls)
254-
cls._instance._initialized = False
255-
return cls._instance
256-
257-
def __init__(self):
258-
if self._initialized:
259-
return
251+
_clients = {} # Map of connection_uuid -> TelemetryClient
252+
_executor = None
253+
_initialized = False
260254

261-
self._clients = {} # Map of connection_uuid -> TelemetryClient
262-
self.executor = ThreadPoolExecutor(
263-
max_workers=10
264-
) # Thread pool for async operations TODO: Decide on max workers
265-
self._initialized = True
255+
@classmethod
256+
def _initialize(cls):
257+
"""Initialize the factory if not already initialized"""
258+
if not cls._initialized:
259+
logger.info("Initializing TelemetryClientFactory")
260+
cls._clients = {}
261+
cls._executor = ThreadPoolExecutor(
262+
max_workers=10
263+
) # Thread pool for async operations TODO: Decide on max workers
264+
cls._initialized = True
265+
logger.debug(
266+
"TelemetryClientFactory initialized with thread pool (max_workers=10)"
267+
)
266268

269+
@staticmethod
267270
def initialize_telemetry_client(
268-
self,
269271
telemetry_enabled,
270272
connection_uuid,
271273
auth_provider,
272-
user_agent,
273-
driver_connection_params,
274+
host_url,
274275
):
275276
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
277+
TelemetryClientFactory._initialize()
278+
276279
if telemetry_enabled:
277-
if connection_uuid not in self._clients:
278-
self._clients[connection_uuid] = TelemetryClient(
280+
if connection_uuid not in TelemetryClientFactory._clients:
281+
logger.info(
282+
f"Creating new TelemetryClient for connection {connection_uuid}"
283+
)
284+
TelemetryClientFactory._clients[connection_uuid] = TelemetryClient(
279285
telemetry_enabled=telemetry_enabled,
280286
connection_uuid=connection_uuid,
281287
auth_provider=auth_provider,
282-
user_agent=user_agent,
283-
driver_connection_params=driver_connection_params,
284-
executor=self.executor,
288+
host_url=host_url,
289+
executor=TelemetryClientFactory._executor,
285290
)
286-
return self._clients[connection_uuid]
291+
return TelemetryClientFactory._clients[connection_uuid]
287292
else:
288293
return NoopTelemetryClient()
289294

290-
def get_telemetry_client(self, connection_uuid):
295+
@staticmethod
296+
def get_telemetry_client(connection_uuid):
291297
"""Get the telemetry client for a specific connection"""
292-
if connection_uuid in self._clients:
293-
return self._clients[connection_uuid]
298+
if connection_uuid in TelemetryClientFactory._clients:
299+
return TelemetryClientFactory._clients[connection_uuid]
294300
else:
295301
return NoopTelemetryClient()
296302

297-
def close(self, connection_uuid):
298-
if connection_uuid in self._clients:
303+
@staticmethod
304+
def close(connection_uuid):
305+
"""Close and remove the telemetry client for a specific connection"""
306+
307+
if connection_uuid in TelemetryClientFactory._clients:
299308
logger.debug(f"Removing telemetry client for connection {connection_uuid}")
300-
del self._clients[connection_uuid]
309+
del TelemetryClientFactory._clients[connection_uuid]
301310

302311
# Shutdown executor if no more clients
303-
if not self._clients:
312+
if not TelemetryClientFactory._clients and TelemetryClientFactory._executor:
304313
logger.info("No more telemetry clients, shutting down thread pool executor")
305-
self.executor.shutdown(wait=True)
306-
307-
308-
# Create a global instance
309-
telemetry_client_factory = TelemetryClientFactory()
314+
TelemetryClientFactory._executor.shutdown(wait=True)
315+
TelemetryClientFactory._executor = None
316+
TelemetryClientFactory._initialized = False

0 commit comments

Comments
 (0)