Skip to content

Commit d237255

Browse files
committed
removed debugging, added TelemetryClientFactory
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 2c293a5 commit d237255

File tree

5 files changed

+231
-219
lines changed

5 files changed

+231
-219
lines changed

src/databricks/sql/client.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@
5454
)
5555
from databricks.sql.telemetry.telemetry_client import (
5656
TelemetryHelper,
57-
initialize_telemetry_client,
58-
get_telemetry_client,
59-
close_telemetry_client,
57+
TelemetryClientFactory,
6058
)
6159
from databricks.sql.telemetry.models.enums import DatabricksClientType
6260
from databricks.sql.telemetry.models.event import (
@@ -308,14 +306,14 @@ def read(self) -> Optional[OAuthToken]:
308306
kwargs.get("use_inline_params", False)
309307
)
310308

311-
initialize_telemetry_client(
309+
TelemetryClientFactory.initialize_telemetry_client(
312310
telemetry_enabled=self.telemetry_enabled,
313311
session_id_hex=self.get_session_id_hex(),
314312
auth_provider=auth_provider,
315313
host_url=self.host,
316314
)
317315

318-
self._telemetry_client = get_telemetry_client(
316+
self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
319317
session_id_hex=self.get_session_id_hex()
320318
)
321319

@@ -472,7 +470,7 @@ def _close(self, close_cursors=True) -> None:
472470

473471
self.open = False
474472

475-
close_telemetry_client(self.get_session_id_hex())
473+
TelemetryClientFactory.close(self.get_session_id_hex())
476474

477475
def commit(self):
478476
"""No-op because Databricks does not support transactions"""

src/databricks/sql/exc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33

4-
from databricks.sql.telemetry.telemetry_client import get_telemetry_client
4+
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
55

66
logger = logging.getLogger(__name__)
77

@@ -22,7 +22,9 @@ def __init__(
2222

2323
error_name = self.__class__.__name__
2424
if session_id_hex:
25-
telemetry_client = get_telemetry_client(session_id_hex)
25+
telemetry_client = TelemetryClientFactory.get_telemetry_client(
26+
session_id_hex
27+
)
2628
telemetry_client.export_failure_log(error_name, self.message)
2729

2830
def __str__(self):

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 123 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -305,111 +305,131 @@ def close(self):
305305
self._flush()
306306

307307

308-
# Module-level state
309-
_clients: Dict[str, BaseTelemetryClient] = {}
310-
_executor: Optional[ThreadPoolExecutor] = None
311-
_initialized: bool = False
312-
_lock = threading.Lock()
313-
_original_excepthook = None
314-
_excepthook_installed = False
315-
316-
317-
def _initialize():
318-
"""Initialize the telemetry system if not already initialized"""
319-
global _initialized, _executor
320-
if not _initialized:
321-
_clients.clear()
322-
_executor = ThreadPoolExecutor(max_workers=10)
323-
_install_exception_hook()
324-
_initialized = True
325-
logger.debug("Telemetry system initialized with thread pool (max_workers=10)")
326-
327-
328-
def _install_exception_hook():
329-
"""Install global exception handler for unhandled exceptions"""
330-
global _excepthook_installed, _original_excepthook
331-
if not _excepthook_installed:
332-
_original_excepthook = sys.excepthook
333-
sys.excepthook = _handle_unhandled_exception
334-
_excepthook_installed = True
335-
logger.debug("Global exception handler installed for telemetry")
336-
337-
338-
def _handle_unhandled_exception(exc_type, exc_value, exc_traceback):
339-
"""Handle unhandled exceptions by sending telemetry and flushing thread pool"""
340-
logger.debug("Handling unhandled exception: %s", exc_type.__name__)
341-
342-
clients_to_close = list(_clients.values())
343-
for client in clients_to_close:
344-
client.close()
345-
346-
# Call the original exception handler to maintain normal behavior
347-
if _original_excepthook:
348-
_original_excepthook(exc_type, exc_value, exc_traceback)
349-
350-
351-
def initialize_telemetry_client(
352-
telemetry_enabled, session_id_hex, auth_provider, host_url
353-
):
354-
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
355-
try:
356-
with _lock:
357-
_initialize()
358-
if session_id_hex not in _clients:
359-
logger.debug(
360-
"Creating new TelemetryClient for connection %s", session_id_hex
361-
)
362-
if telemetry_enabled:
363-
_clients[session_id_hex] = TelemetryClient(
364-
telemetry_enabled=telemetry_enabled,
365-
session_id_hex=session_id_hex,
366-
auth_provider=auth_provider,
367-
host_url=host_url,
368-
executor=_executor,
369-
)
370-
print("i have initialized the telemetry client yes")
371-
else:
372-
_clients[session_id_hex] = NoopTelemetryClient()
373-
print("i have initialized the noop client yes")
374-
except Exception as e:
375-
logger.debug("Failed to initialize telemetry client: %s", e)
376-
# Fallback to NoopTelemetryClient to ensure connection doesn't fail
377-
_clients[session_id_hex] = NoopTelemetryClient()
378-
379-
380-
def get_telemetry_client(session_id_hex):
381-
"""Get the telemetry client for a specific connection"""
382-
try:
383-
if session_id_hex in _clients:
384-
return _clients[session_id_hex]
385-
else:
386-
logger.error(
387-
"Telemetry client not initialized for connection %s", session_id_hex
308+
class TelemetryClientFactory:
309+
"""
310+
Static factory class for creating and managing telemetry clients.
311+
It uses a thread pool to handle asynchronous operations.
312+
"""
313+
314+
_clients: Dict[
315+
str, BaseTelemetryClient
316+
] = {} # Map of session_id_hex -> BaseTelemetryClient
317+
_executor: Optional[ThreadPoolExecutor] = None
318+
_initialized: bool = False
319+
_lock = threading.Lock() # Thread safety for factory operations
320+
_original_excepthook = None
321+
_excepthook_installed = False
322+
323+
@classmethod
324+
def _initialize(cls):
325+
"""Initialize the factory if not already initialized"""
326+
327+
if not cls._initialized:
328+
cls._clients = {}
329+
cls._executor = ThreadPoolExecutor(
330+
max_workers=10
331+
) # Thread pool for async operations TODO: Decide on max workers
332+
cls._install_exception_hook()
333+
cls._initialized = True
334+
logger.debug(
335+
"TelemetryClientFactory initialized with thread pool (max_workers=10)"
388336
)
389-
return NoopTelemetryClient()
390-
except Exception as e:
391-
logger.debug("Failed to get telemetry client: %s", e)
392-
return NoopTelemetryClient()
393-
394-
395-
def close_telemetry_client(session_id_hex):
396-
"""Remove the telemetry client for a specific connection"""
397-
global _initialized, _executor
398-
with _lock:
399-
# if (telemetry_client := _clients.pop(session_id_hex, None)) is not None:
400-
if session_id_hex in _clients:
401-
telemetry_client = _clients.pop(session_id_hex)
402-
logger.debug("Removing telemetry client for connection %s", session_id_hex)
403-
telemetry_client.close()
404-
405-
# Shutdown executor if no more clients
337+
338+
@classmethod
339+
def _install_exception_hook(cls):
340+
"""Install global exception handler for unhandled exceptions"""
341+
if not cls._excepthook_installed:
342+
cls._original_excepthook = sys.excepthook
343+
sys.excepthook = cls._handle_unhandled_exception
344+
cls._excepthook_installed = True
345+
logger.debug("Global exception handler installed for telemetry")
346+
347+
@classmethod
348+
def _handle_unhandled_exception(cls, exc_type, exc_value, exc_traceback):
349+
"""Handle unhandled exceptions by sending telemetry and flushing thread pool"""
350+
logger.debug("Handling unhandled exception: %s", exc_type.__name__)
351+
352+
clients_to_close = list(cls._clients.values())
353+
for client in clients_to_close:
354+
client.close()
355+
356+
# Call the original exception handler to maintain normal behavior
357+
if cls._original_excepthook:
358+
cls._original_excepthook(exc_type, exc_value, exc_traceback)
359+
360+
@staticmethod
361+
def initialize_telemetry_client(
362+
telemetry_enabled,
363+
session_id_hex,
364+
auth_provider,
365+
host_url,
366+
):
367+
"""Initialize a telemetry client for a specific connection if telemetry is enabled"""
406368
try:
407-
if not _clients and _executor:
369+
370+
with TelemetryClientFactory._lock:
371+
TelemetryClientFactory._initialize()
372+
373+
if session_id_hex not in TelemetryClientFactory._clients:
374+
logger.debug(
375+
"Creating new TelemetryClient for connection %s",
376+
session_id_hex,
377+
)
378+
if telemetry_enabled:
379+
TelemetryClientFactory._clients[
380+
session_id_hex
381+
] = TelemetryClient(
382+
telemetry_enabled=telemetry_enabled,
383+
session_id_hex=session_id_hex,
384+
auth_provider=auth_provider,
385+
host_url=host_url,
386+
executor=TelemetryClientFactory._executor,
387+
)
388+
else:
389+
TelemetryClientFactory._clients[
390+
session_id_hex
391+
] = NoopTelemetryClient()
392+
except Exception as e:
393+
logger.debug("Failed to initialize telemetry client: %s", e)
394+
# Fallback to NoopTelemetryClient to ensure connection doesn't fail
395+
TelemetryClientFactory._clients[session_id_hex] = NoopTelemetryClient()
396+
397+
@staticmethod
398+
def get_telemetry_client(session_id_hex):
399+
"""Get the telemetry client for a specific connection"""
400+
try:
401+
if session_id_hex in TelemetryClientFactory._clients:
402+
return TelemetryClientFactory._clients[session_id_hex]
403+
else:
404+
logger.error(
405+
"Telemetry client not initialized for connection %s",
406+
session_id_hex,
407+
)
408+
return NoopTelemetryClient()
409+
except Exception as e:
410+
logger.debug("Failed to get telemetry client: %s", e)
411+
return NoopTelemetryClient()
412+
413+
@staticmethod
414+
def close(session_id_hex):
415+
"""Close and remove the telemetry client for a specific connection"""
416+
417+
with TelemetryClientFactory._lock:
418+
if (
419+
telemetry_client := TelemetryClientFactory._clients.pop(
420+
session_id_hex, None
421+
)
422+
) is not None:
423+
logger.debug(
424+
"Removing telemetry client for connection %s", session_id_hex
425+
)
426+
telemetry_client.close()
427+
428+
# Shutdown executor if no more clients
429+
if not TelemetryClientFactory._clients and TelemetryClientFactory._executor:
408430
logger.debug(
409431
"No more telemetry clients, shutting down thread pool executor"
410432
)
411-
_executor.shutdown(wait=True)
412-
_executor = None
413-
_initialized = False
414-
except Exception as e:
415-
logger.debug("Failed to shutdown thread pool executor: %s", e)
433+
TelemetryClientFactory._executor.shutdown(wait=True)
434+
TelemetryClientFactory._executor = None
435+
TelemetryClientFactory._initialized = False

tests/unit/test_client.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,6 @@ def test_negative_fetch_throws_exception(self):
336336
result_set.fetchmany(-1)
337337

338338
def test_context_manager_closes_cursor(self):
339-
print("hellow")
340339
mock_close = Mock()
341340
with client.Cursor(Mock(), Mock()) as cursor:
342341
cursor.close = mock_close
@@ -352,8 +351,7 @@ def test_context_manager_closes_cursor(self):
352351
cursor.close.assert_called()
353352

354353
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
355-
def test_a_context_manager_closes_connection(self, mock_client_class):
356-
print("hellow1")
354+
def test_context_manager_closes_connection(self, mock_client_class):
357355
instance = mock_client_class.return_value
358356

359357
mock_open_session_resp = MagicMock(spec=TOpenSessionResp)()
@@ -792,7 +790,6 @@ def test_cursor_context_manager_handles_exit_exception(self):
792790

793791
def test_connection_close_handles_cursor_close_exception(self):
794792
"""Test that _close handles exceptions from cursor.close() properly."""
795-
print("banana")
796793
cursors_closed = []
797794

798795
def mock_close_with_exception():

0 commit comments

Comments
 (0)