Skip to content

Commit db9c2c7

Browse files
committed
convert TelemetryClientFactory to module-level functions, replace NoopTelemetryClient class with NOOP_TELEMETRY_CLIENT singleton, updated tests accordingly
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent e5609c2 commit db9c2c7

File tree

4 files changed

+167
-251
lines changed

4 files changed

+167
-251
lines changed

src/databricks/sql/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@
5353
TOperationState,
5454
)
5555
from databricks.sql.telemetry.telemetry_client import (
56-
TelemetryClientFactory,
5756
TelemetryHelper,
57+
initialize_telemetry_client,
58+
get_telemetry_client,
5859
)
5960
from databricks.sql.telemetry.models.enums import DatabricksClientType
6061
from databricks.sql.telemetry.models.event import (
@@ -306,14 +307,14 @@ def read(self) -> Optional[OAuthToken]:
306307
kwargs.get("use_inline_params", False)
307308
)
308309

309-
TelemetryClientFactory.initialize_telemetry_client(
310+
initialize_telemetry_client(
310311
telemetry_enabled=self.telemetry_enabled,
311312
connection_uuid=self.get_session_id_hex(),
312313
auth_provider=auth_provider,
313314
host_url=self.host,
314315
)
315316

316-
self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
317+
self._telemetry_client = get_telemetry_client(
317318
connection_uuid=self.get_session_id_hex()
318319
)
319320

src/databricks/sql/exc.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import traceback
44

5-
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
5+
from databricks.sql.telemetry.telemetry_client import get_telemetry_client
66

77
logger = logging.getLogger(__name__)
88

@@ -22,9 +22,7 @@ def __init__(
2222

2323
error_name = self.__class__.__name__
2424
if connection_uuid:
25-
telemetry_client = TelemetryClientFactory.get_telemetry_client(
26-
connection_uuid
27-
)
25+
telemetry_client = get_telemetry_client(connection_uuid)
2826
telemetry_client.export_failure_log(error_name, self.message)
2927

3028
def __str__(self):

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 107 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -115,27 +115,16 @@ def close(self):
115115
pass
116116

117117

118-
class NoopTelemetryClient(BaseTelemetryClient):
119-
"""
120-
NoopTelemetryClient is a telemetry client that does not send any events to the server.
121-
It is used when telemetry is disabled.
122-
"""
123-
124-
_instance = None
125-
126-
def __new__(cls):
127-
if cls._instance is None:
128-
cls._instance = super(NoopTelemetryClient, cls).__new__(cls)
129-
return cls._instance
130-
131-
def export_initial_telemetry_log(self, driver_connection_params, user_agent):
132-
pass
133-
134-
def export_failure_log(self, error_name, error_message):
135-
pass
136-
137-
def close(self):
138-
pass
118+
# A single instance of the no-op client that can be reused
119+
NOOP_TELEMETRY_CLIENT = type(
120+
"NoopTelemetryClient",
121+
(BaseTelemetryClient,),
122+
{
123+
"export_initial_telemetry_log": lambda self, *args, **kwargs: None,
124+
"export_failure_log": lambda self, *args, **kwargs: None,
125+
"close": lambda self: None,
126+
},
127+
)()
139128

140129

141130
class TelemetryClient(BaseTelemetryClient):
@@ -301,129 +290,111 @@ def close(self):
301290
"""Flush remaining events before closing"""
302291
logger.debug("Closing TelemetryClient for connection %s", self._connection_uuid)
303292
self._flush()
304-
TelemetryClientFactory.close(self._connection_uuid)
305-
306-
307-
class TelemetryClientFactory:
308-
"""
309-
Static factory class for creating and managing telemetry clients.
310-
It uses a thread pool to handle asynchronous operations.
311-
"""
312-
313-
_clients: Dict[
314-
str, BaseTelemetryClient
315-
] = {} # Map of connection_uuid -> BaseTelemetryClient
316-
_executor: Optional[ThreadPoolExecutor] = None
317-
_initialized: bool = False
318-
_lock = threading.Lock() # Thread safety for factory operations
319-
_original_excepthook = None
320-
_excepthook_installed = False
321-
322-
@classmethod
323-
def _initialize(cls):
324-
"""Initialize the factory if not already initialized"""
325-
with cls._lock:
326-
if not cls._initialized:
327-
cls._clients = {}
328-
cls._executor = ThreadPoolExecutor(
329-
max_workers=10
330-
) # Thread pool for async operations TODO: Decide on max workers
331-
cls._install_exception_hook()
332-
cls._initialized = True
333-
logger.debug(
334-
"TelemetryClientFactory initialized with thread pool (max_workers=10)"
335-
)
336-
337-
@classmethod
338-
def _install_exception_hook(cls):
339-
"""Install global exception handler for unhandled exceptions"""
340-
if not cls._excepthook_installed:
341-
cls._original_excepthook = sys.excepthook
342-
sys.excepthook = cls._handle_unhandled_exception
343-
cls._excepthook_installed = True
344-
logger.debug("Global exception handler installed for telemetry")
293+
_remove_telemetry_client(self._connection_uuid)
294+
295+
296+
# Module-level state
297+
_clients: Dict[str, BaseTelemetryClient] = {}
298+
_executor: Optional[ThreadPoolExecutor] = None
299+
_initialized: bool = False
300+
_lock = threading.Lock()
301+
_original_excepthook = None
302+
_excepthook_installed = False
303+
304+
305+
def _initialize():
306+
"""Initialize the telemetry system if not already initialized"""
307+
global _initialized, _executor
308+
with _lock:
309+
if not _initialized:
310+
_clients.clear()
311+
_executor = ThreadPoolExecutor(max_workers=10)
312+
_install_exception_hook()
313+
_initialized = True
314+
logger.debug(
315+
"Telemetry system initialized with thread pool (max_workers=10)"
316+
)
345317

346-
@classmethod
347-
def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback):
348-
"""Handle unhandled exceptions by sending telemetry and flushing thread pool"""
349-
logger.debug("Handling unhandled exception: %s", exc_type.__name__)
350318

351-
clients_to_close = list(cls._clients.values())
352-
for client in clients_to_close:
353-
client.close()
319+
def _install_exception_hook():
320+
"""Install global exception handler for unhandled exceptions"""
321+
global _excepthook_installed, _original_excepthook
322+
if not _excepthook_installed:
323+
_original_excepthook = sys.excepthook
324+
sys.excepthook = _handle_unhandled_exception
325+
_excepthook_installed = True
326+
logger.debug("Global exception handler installed for telemetry")
354327

355-
# Call the original exception handler to maintain normal behavior
356-
if cls._original_excepthook:
357-
cls._original_excepthook(exc_type, exc_value, exc_traceback)
358328

359-
@staticmethod
360-
def initialize_telemetry_client(
361-
telemetry_enabled,
362-
connection_uuid,
363-
auth_provider,
364-
host_url,
365-
):
366-
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
367-
try:
368-
TelemetryClientFactory._initialize()
329+
def _handle_unhandled_exception(exc_type, exc_value, exc_traceback):
330+
"""Handle unhandled exceptions by sending telemetry and flushing thread pool"""
331+
logger.debug("Handling unhandled exception: %s", exc_type.__name__)
369332

370-
with TelemetryClientFactory._lock:
371-
if connection_uuid not in TelemetryClientFactory._clients:
372-
logger.debug(
373-
"Creating new TelemetryClient for connection %s",
374-
connection_uuid,
375-
)
376-
if telemetry_enabled:
377-
TelemetryClientFactory._clients[
378-
connection_uuid
379-
] = TelemetryClient(
380-
telemetry_enabled=telemetry_enabled,
381-
connection_uuid=connection_uuid,
382-
auth_provider=auth_provider,
383-
host_url=host_url,
384-
executor=TelemetryClientFactory._executor,
385-
)
386-
else:
387-
TelemetryClientFactory._clients[
388-
connection_uuid
389-
] = NoopTelemetryClient()
390-
except Exception as e:
391-
logger.debug("Failed to initialize telemetry client: %s", e)
392-
# Fallback to NoopTelemetryClient to ensure connection doesn't fail
393-
TelemetryClientFactory._clients[connection_uuid] = NoopTelemetryClient()
333+
clients_to_close = list(_clients.values())
334+
for client in clients_to_close:
335+
client.close()
394336

395-
@staticmethod
396-
def get_telemetry_client(connection_uuid):
397-
"""Get the telemetry client for a specific connection"""
398-
try:
399-
if connection_uuid in TelemetryClientFactory._clients:
400-
return TelemetryClientFactory._clients[connection_uuid]
401-
else:
402-
logger.error(
403-
"Telemetry client not initialized for connection %s",
404-
connection_uuid,
405-
)
406-
return NoopTelemetryClient()
407-
except Exception as e:
408-
logger.debug("Failed to get telemetry client: %s", e)
409-
return NoopTelemetryClient()
337+
# Call the original exception handler to maintain normal behavior
338+
if _original_excepthook:
339+
_original_excepthook(exc_type, exc_value, exc_traceback)
410340

411-
@staticmethod
412-
def close(connection_uuid):
413-
"""Close and remove the telemetry client for a specific connection"""
414341

415-
with TelemetryClientFactory._lock:
416-
if connection_uuid in TelemetryClientFactory._clients:
417-
logger.debug(
418-
"Removing telemetry client for connection %s", connection_uuid
419-
)
420-
TelemetryClientFactory._clients.pop(connection_uuid, None)
342+
def initialize_telemetry_client(
343+
telemetry_enabled, connection_uuid, auth_provider, host_url
344+
):
345+
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
346+
try:
347+
_initialize()
421348

422-
# Shutdown executor if no more clients
423-
if not TelemetryClientFactory._clients and TelemetryClientFactory._executor:
349+
with _lock:
350+
if connection_uuid not in _clients:
424351
logger.debug(
425-
"No more telemetry clients, shutting down thread pool executor"
352+
"Creating new TelemetryClient for connection %s", connection_uuid
426353
)
427-
TelemetryClientFactory._executor.shutdown(wait=True)
428-
TelemetryClientFactory._executor = None
429-
TelemetryClientFactory._initialized = False
354+
if telemetry_enabled:
355+
_clients[connection_uuid] = TelemetryClient(
356+
telemetry_enabled=telemetry_enabled,
357+
connection_uuid=connection_uuid,
358+
auth_provider=auth_provider,
359+
host_url=host_url,
360+
executor=_executor,
361+
)
362+
else:
363+
_clients[connection_uuid] = NOOP_TELEMETRY_CLIENT
364+
except Exception as e:
365+
logger.debug("Failed to initialize telemetry client: %s", e)
366+
# Fallback to NoopTelemetryClient to ensure connection doesn't fail
367+
_clients[connection_uuid] = NOOP_TELEMETRY_CLIENT
368+
369+
370+
def get_telemetry_client(connection_uuid):
371+
"""Get the telemetry client for a specific connection"""
372+
try:
373+
if connection_uuid in _clients:
374+
return _clients[connection_uuid]
375+
else:
376+
logger.error(
377+
"Telemetry client not initialized for connection %s", connection_uuid
378+
)
379+
return NOOP_TELEMETRY_CLIENT
380+
except Exception as e:
381+
logger.debug("Failed to get telemetry client: %s", e)
382+
return NOOP_TELEMETRY_CLIENT
383+
384+
385+
def _remove_telemetry_client(connection_uuid):
386+
"""Remove the telemetry client for a specific connection"""
387+
global _initialized, _executor
388+
with _lock:
389+
if connection_uuid in _clients:
390+
logger.debug("Removing telemetry client for connection %s", connection_uuid)
391+
_clients.pop(connection_uuid, None)
392+
393+
# Shutdown executor if no more clients
394+
if not _clients and _executor:
395+
logger.debug(
396+
"No more telemetry clients, shutting down thread pool executor"
397+
)
398+
_executor.shutdown(wait=True)
399+
_executor = None
400+
_initialized = False

0 commit comments

Comments
 (0)