Skip to content

Commit 34b63e4

Browse files
committed
added teardown to all tests except finalizer test (gc collect)
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 4466821 commit 34b63e4

File tree

3 files changed

+116
-78
lines changed

3 files changed

+116
-78
lines changed

src/databricks/sql/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def _close(self, close_cursors=True) -> None:
446446
if close_cursors:
447447
for cursor in self._cursors:
448448
cursor.close()
449-
449+
print(f"Closing session {self.get_session_id_hex()}")
450450
logger.info(f"Closing session {self.get_session_id_hex()}")
451451
if not self.open:
452452
logger.debug("Session appears to have been closed already")

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 77 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -31,79 +31,79 @@
3131
logger = logging.getLogger(__name__)
3232

3333

34-
class DebugLock:
35-
"""A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release"""
36-
37-
def __init__(self, name: str = "DebugLock"):
38-
self._lock = threading.Lock()
39-
self._name = name
40-
self._owner: Optional[str] = None
41-
self._waiters: List[str] = []
42-
self._debug_logger = logging.getLogger(f"{__name__}.{name}")
43-
# Ensure debug logging is visible
44-
if not self._debug_logger.handlers:
45-
handler = logging.StreamHandler()
46-
formatter = logging.Formatter(
47-
":lock: %(asctime)s [%(threadName)s-%(thread)d] LOCK-%(name)s: %(message)s"
48-
)
49-
handler.setFormatter(formatter)
50-
self._debug_logger.addHandler(handler)
51-
self._debug_logger.setLevel(logging.DEBUG)
52-
53-
def acquire(self, blocking=True, timeout=-1):
54-
current = threading.current_thread()
55-
thread_info = f"{current.name}-{current.ident}"
56-
if self._owner:
57-
self._debug_logger.warning(
58-
f": WAITING: {thread_info} waiting for lock held by {self._owner}"
59-
)
60-
self._waiters.append(thread_info)
61-
else:
62-
self._debug_logger.debug(
63-
f": TRYING: {thread_info} attempting to acquire lock"
64-
)
65-
# Try to acquire the lock
66-
acquired = self._lock.acquire(blocking, timeout)
67-
if acquired:
68-
self._owner = thread_info
69-
self._debug_logger.info(f": ACQUIRED: {thread_info} got the lock")
70-
if self._waiters:
71-
self._debug_logger.info(
72-
f": WAITERS: {len(self._waiters)} threads waiting: {self._waiters}"
73-
)
74-
else:
75-
self._debug_logger.error(
76-
f": FAILED: {thread_info} failed to acquire lock (timeout)"
77-
)
78-
if thread_info in self._waiters:
79-
self._waiters.remove(thread_info)
80-
return acquired
81-
82-
def release(self):
83-
current = threading.current_thread()
84-
thread_info = f"{current.name}-{current.ident}"
85-
if self._owner != thread_info:
86-
self._debug_logger.error(
87-
f": ERROR: {thread_info} trying to release lock owned by {self._owner}"
88-
)
89-
else:
90-
self._debug_logger.info(f": RELEASED: {thread_info} released the lock")
91-
self._owner = None
92-
# Remove from waiters if present
93-
if thread_info in self._waiters:
94-
self._waiters.remove(thread_info)
95-
if self._waiters:
96-
self._debug_logger.info(
97-
f": NEXT: {len(self._waiters)} threads still waiting: {self._waiters}"
98-
)
99-
self._lock.release()
100-
101-
def __enter__(self):
102-
self.acquire()
103-
return self
104-
105-
def __exit__(self, exc_type, exc_val, exc_tb):
106-
self.release()
34+
# class DebugLock:
35+
# """A wrapper around threading.Lock that provides detailed debugging for lock acquisition/release"""
36+
37+
# def __init__(self, name: str = "DebugLock"):
38+
# self._lock = threading.Lock()
39+
# self._name = name
40+
# self._owner: Optional[str] = None
41+
# self._waiters: List[str] = []
42+
# self._debug_logger = logging.getLogger(f"{__name__}.{name}")
43+
# # Ensure debug logging is visible
44+
# if not self._debug_logger.handlers:
45+
# handler = logging.StreamHandler()
46+
# formatter = logging.Formatter(
47+
# ":lock: %(asctime)s [%(threadName)s-%(thread)d] LOCK-%(name)s: %(message)s"
48+
# )
49+
# handler.setFormatter(formatter)
50+
# self._debug_logger.addHandler(handler)
51+
# self._debug_logger.setLevel(logging.DEBUG)
52+
53+
# def acquire(self, blocking=True, timeout=-1):
54+
# current = threading.current_thread()
55+
# thread_info = f"{current.name}-{current.ident}"
56+
# if self._owner:
57+
# self._debug_logger.warning(
58+
# f": WAITING: {thread_info} waiting for lock held by {self._owner}"
59+
# )
60+
# self._waiters.append(thread_info)
61+
# else:
62+
# self._debug_logger.debug(
63+
# f": TRYING: {thread_info} attempting to acquire lock"
64+
# )
65+
# # Try to acquire the lock
66+
# acquired = self._lock.acquire(blocking, timeout)
67+
# if acquired:
68+
# self._owner = thread_info
69+
# self._debug_logger.info(f": ACQUIRED: {thread_info} got the lock")
70+
# if self._waiters:
71+
# self._debug_logger.info(
72+
# f": WAITERS: {len(self._waiters)} threads waiting: {self._waiters}"
73+
# )
74+
# else:
75+
# self._debug_logger.error(
76+
# f": FAILED: {thread_info} failed to acquire lock (timeout)"
77+
# )
78+
# if thread_info in self._waiters:
79+
# self._waiters.remove(thread_info)
80+
# return acquired
81+
82+
# def release(self):
83+
# current = threading.current_thread()
84+
# thread_info = f"{current.name}-{current.ident}"
85+
# if self._owner != thread_info:
86+
# self._debug_logger.error(
87+
# f": ERROR: {thread_info} trying to release lock owned by {self._owner}"
88+
# )
89+
# else:
90+
# self._debug_logger.info(f": RELEASED: {thread_info} released the lock")
91+
# self._owner = None
92+
# # Remove from waiters if present
93+
# if thread_info in self._waiters:
94+
# self._waiters.remove(thread_info)
95+
# if self._waiters:
96+
# self._debug_logger.info(
97+
# f": NEXT: {len(self._waiters)} threads still waiting: {self._waiters}"
98+
# )
99+
# self._lock.release()
100+
101+
# def __enter__(self):
102+
# self.acquire()
103+
# return self
104+
105+
# def __exit__(self, exc_type, exc_val, exc_tb):
106+
# self.release()
107107

108108

109109
class TelemetryHelper:
@@ -430,10 +430,10 @@ class TelemetryClientFactory:
430430
] = {} # Map of session_id_hex -> BaseTelemetryClient
431431
_executor: Optional[ThreadPoolExecutor] = None
432432
_initialized: bool = False
433-
# _lock = threading.Lock() # Thread safety for factory operations
434-
_lock = DebugLock(
435-
"TelemetryClientFactory"
436-
) # Thread safety for factory operations with debugging
433+
_lock = threading.Lock() # Thread safety for factory operations
434+
# _lock = DebugLock(
435+
# "TelemetryClientFactory"
436+
# ) # Thread safety for factory operations with debugging
437437
_original_excepthook = None
438438
_excepthook_installed = False
439439

tests/unit/test_client.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,44 @@ class ClientTestSuite(unittest.TestCase):
8383
"access_token": "tok",
8484
}
8585

86+
def setUp(self):
87+
"""Set up connection tracking before each test"""
88+
self.tracked_connections = []
89+
90+
# Store the original connect function
91+
self.original_connect = databricks.sql.connect
92+
93+
def patched_connect(*args, **kwargs):
94+
"""Wrapper that tracks all created connections"""
95+
conn = self.original_connect(*args, **kwargs)
96+
97+
# Skip tracking for finalizer tests to allow garbage collection
98+
if not (hasattr(self, '_testMethodName') and 'finalizer' in self._testMethodName):
99+
self.tracked_connections.append(conn)
100+
101+
return conn
102+
103+
# Apply the patch to track connections
104+
self.connect_patcher = patch('databricks.sql.connect', patched_connect)
105+
self.connect_patcher.start()
106+
107+
def tearDown(self):
108+
"""Clean up connections after each test"""
109+
# Close all tracked connections
110+
for conn in self.tracked_connections:
111+
try:
112+
if hasattr(conn, 'open') and conn.open:
113+
conn.close()
114+
except Exception as e:
115+
# Log the error but don't fail the test
116+
print(f"Warning: Error closing connection in tearDown: {e}")
117+
118+
# Stop the connect patcher
119+
self.connect_patcher.stop()
120+
121+
# Clear the tracked connections list
122+
self.tracked_connections.clear()
123+
86124
@patch("%s.client.ThriftBackend" % PACKAGE_NAME)
87125
def test_close_uses_the_correct_session_id(self, mock_client_class):
88126
instance = mock_client_class.return_value

0 commit comments

Comments
 (0)