Skip to content

Commit 9615519

Browse files
committed
changed unittest to pytest
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent e218ef9 commit 9615519

File tree

1 file changed

+141
-99
lines changed

1 file changed

+141
-99
lines changed

tests/unit/test_telemetry.py

Lines changed: 141 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import unittest
21
import uuid
2+
import pytest
33
import requests
44
from unittest.mock import patch, MagicMock, call
55

@@ -20,112 +20,161 @@
2020
AccessTokenAuthProvider,
2121
)
2222

23-
class TestNoopTelemetryClient(unittest.TestCase):
23+
24+
@pytest.fixture
25+
def noop_telemetry_client():
26+
"""Fixture for NoopTelemetryClient."""
27+
return NoopTelemetryClient()
28+
29+
30+
@pytest.fixture
31+
def telemetry_client_setup():
32+
"""Fixture for TelemetryClient setup data."""
33+
connection_uuid = str(uuid.uuid4())
34+
auth_provider = AccessTokenAuthProvider("test-token")
35+
host_url = "test-host"
36+
executor = MagicMock()
37+
38+
client = TelemetryClient(
39+
telemetry_enabled=True,
40+
connection_uuid=connection_uuid,
41+
auth_provider=auth_provider,
42+
host_url=host_url,
43+
executor=executor,
44+
)
45+
46+
return {
47+
"client": client,
48+
"connection_uuid": connection_uuid,
49+
"auth_provider": auth_provider,
50+
"host_url": host_url,
51+
"executor": executor,
52+
}
53+
54+
55+
@pytest.fixture
56+
def telemetry_factory_reset():
57+
"""Fixture to reset TelemetryClientFactory state before each test."""
58+
# Reset the static class state before each test
59+
TelemetryClientFactory._clients = {}
60+
TelemetryClientFactory._executor = None
61+
TelemetryClientFactory._initialized = False
62+
yield
63+
# Cleanup after test if needed
64+
TelemetryClientFactory._clients = {}
65+
if TelemetryClientFactory._executor:
66+
TelemetryClientFactory._executor.shutdown(wait=True)
67+
TelemetryClientFactory._executor = None
68+
TelemetryClientFactory._initialized = False
69+
70+
71+
class TestNoopTelemetryClient:
2472
"""Tests for the NoopTelemetryClient class."""
2573

2674
def test_singleton(self):
2775
"""Test that NoopTelemetryClient is a singleton."""
2876
client1 = NoopTelemetryClient()
2977
client2 = NoopTelemetryClient()
30-
self.assertIs(client1, client2)
78+
assert client1 is client2
3179

32-
def test_export_initial_telemetry_log(self):
80+
def test_export_initial_telemetry_log(self, noop_telemetry_client):
3381
"""Test that export_initial_telemetry_log does nothing."""
34-
client = NoopTelemetryClient()
35-
client.export_initial_telemetry_log(driver_connection_params=MagicMock(), user_agent="test")
82+
noop_telemetry_client.export_initial_telemetry_log(
83+
driver_connection_params=MagicMock(), user_agent="test"
84+
)
3685

37-
def test_close(self):
86+
def test_close(self, noop_telemetry_client):
3887
"""Test that close does nothing."""
39-
client = NoopTelemetryClient()
40-
client.close()
88+
noop_telemetry_client.close()
4189

4290

43-
class TestTelemetryClient(unittest.TestCase):
91+
class TestTelemetryClient:
4492
"""Tests for the TelemetryClient class."""
4593

46-
def setUp(self):
47-
"""Set up test fixtures."""
48-
self.connection_uuid = str(uuid.uuid4())
49-
self.auth_provider = AccessTokenAuthProvider("test-token")
50-
self.host_url = "test-host"
51-
self.executor = MagicMock()
52-
53-
self.client = TelemetryClient(
54-
telemetry_enabled=True,
55-
connection_uuid=self.connection_uuid,
56-
auth_provider=self.auth_provider,
57-
host_url=self.host_url,
58-
executor=self.executor,
59-
)
60-
6194
@patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog")
6295
@patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration")
6396
@patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4")
6497
@patch("databricks.sql.telemetry.telemetry_client.time.time")
65-
def test_export_initial_telemetry_log(self, mock_time, mock_uuid4, mock_get_driver_config, mock_frontend_log):
98+
def test_export_initial_telemetry_log(
99+
self,
100+
mock_time,
101+
mock_uuid4,
102+
mock_get_driver_config,
103+
mock_frontend_log,
104+
telemetry_client_setup
105+
):
66106
"""Test exporting initial telemetry log."""
67107
mock_time.return_value = 1000
68108
mock_uuid4.return_value = "test-uuid"
69109
mock_get_driver_config.return_value = "test-driver-config"
70110
mock_frontend_log.return_value = MagicMock()
71-
72-
self.client.export_event = MagicMock()
111+
112+
client = telemetry_client_setup["client"]
113+
host_url = telemetry_client_setup["host_url"]
114+
client.export_event = MagicMock()
73115

74116
driver_connection_params = DriverConnectionParameters(
75117
http_path="test-path",
76118
mode=DatabricksClientType.THRIFT,
77-
host_info=HostDetails(host_url=self.host_url, port=443),
119+
host_info=HostDetails(host_url=host_url, port=443),
78120
auth_mech=AuthMech.PAT,
79121
auth_flow=None,
80122
)
81123
user_agent = "test-user-agent"
82124

83-
self.client.export_initial_telemetry_log(driver_connection_params, user_agent)
125+
client.export_initial_telemetry_log(driver_connection_params, user_agent)
84126

85127
mock_frontend_log.assert_called_once()
86-
self.client.export_event.assert_called_once_with(mock_frontend_log.return_value)
128+
client.export_event.assert_called_once_with(mock_frontend_log.return_value)
87129

88-
def test_export_event(self):
130+
def test_export_event(self, telemetry_client_setup):
89131
"""Test exporting an event."""
90-
self.client.flush = MagicMock()
132+
client = telemetry_client_setup["client"]
133+
client.flush = MagicMock()
91134

92135
for i in range(5):
93-
self.client.export_event(f"event-{i}")
136+
client.export_event(f"event-{i}")
94137

95-
self.client.flush.assert_not_called()
96-
self.assertEqual(len(self.client._events_batch), 5)
138+
client.flush.assert_not_called()
139+
assert len(client._events_batch) == 5
97140

98141
for i in range(5, 10):
99-
self.client.export_event(f"event-{i}")
142+
client.export_event(f"event-{i}")
100143

101-
self.client.flush.assert_called_once()
102-
self.assertEqual(len(self.client._events_batch), 10)
144+
client.flush.assert_called_once()
145+
assert len(client._events_batch) == 10
103146

104147
@patch("requests.post")
105-
def test_send_telemetry_authenticated(self, mock_post):
148+
def test_send_telemetry_authenticated(self, mock_post, telemetry_client_setup):
106149
"""Test sending telemetry to the server with authentication."""
150+
client = telemetry_client_setup["client"]
151+
executor = telemetry_client_setup["executor"]
152+
107153
events = [MagicMock(), MagicMock()]
108154
events[0].to_json.return_value = '{"event": "1"}'
109155
events[1].to_json.return_value = '{"event": "2"}'
110156

111-
self.client._send_telemetry(events)
157+
client._send_telemetry(events)
112158

113-
self.executor.submit.assert_called_once()
114-
args, kwargs = self.executor.submit.call_args
115-
self.assertEqual(args[0], requests.post)
116-
self.assertEqual(kwargs["timeout"], 10)
117-
self.assertIn("Authorization", kwargs["headers"])
118-
self.assertEqual(kwargs["headers"]["Authorization"], "Bearer test-token")
159+
executor.submit.assert_called_once()
160+
args, kwargs = executor.submit.call_args
161+
assert args[0] == requests.post
162+
assert kwargs["timeout"] == 10
163+
assert "Authorization" in kwargs["headers"]
164+
assert kwargs["headers"]["Authorization"] == "Bearer test-token"
119165

120166
@patch("requests.post")
121-
def test_send_telemetry_unauthenticated(self, mock_post):
167+
def test_send_telemetry_unauthenticated(self, mock_post, telemetry_client_setup):
122168
"""Test sending telemetry to the server without authentication."""
169+
host_url = telemetry_client_setup["host_url"]
170+
executor = telemetry_client_setup["executor"]
171+
123172
unauthenticated_client = TelemetryClient(
124173
telemetry_enabled=True,
125174
connection_uuid=str(uuid.uuid4()),
126175
auth_provider=None, # No auth provider
127-
host_url=self.host_url,
128-
executor=self.executor,
176+
host_url=host_url,
177+
executor=executor,
129178
)
130179

131180
events = [MagicMock(), MagicMock()]
@@ -134,47 +183,43 @@ def test_send_telemetry_unauthenticated(self, mock_post):
134183

135184
unauthenticated_client._send_telemetry(events)
136185

137-
self.executor.submit.assert_called_once()
138-
args, kwargs = self.executor.submit.call_args
139-
self.assertEqual(args[0], requests.post)
140-
self.assertEqual(kwargs["timeout"], 10)
141-
self.assertNotIn("Authorization", kwargs["headers"]) # No auth header
142-
self.assertEqual(kwargs["headers"]["Accept"], "application/json")
143-
self.assertEqual(kwargs["headers"]["Content-Type"], "application/json")
186+
executor.submit.assert_called_once()
187+
args, kwargs = executor.submit.call_args
188+
assert args[0] == requests.post
189+
assert kwargs["timeout"] == 10
190+
assert "Authorization" not in kwargs["headers"] # No auth header
191+
assert kwargs["headers"]["Accept"] == "application/json"
192+
assert kwargs["headers"]["Content-Type"] == "application/json"
144193

145-
def test_flush(self):
194+
def test_flush(self, telemetry_client_setup):
146195
"""Test flushing events."""
147-
self.client._events_batch = ["event1", "event2"]
148-
self.client._send_telemetry = MagicMock()
196+
client = telemetry_client_setup["client"]
197+
client._events_batch = ["event1", "event2"]
198+
client._send_telemetry = MagicMock()
149199

150-
self.client.flush()
200+
client.flush()
151201

152-
self.client._send_telemetry.assert_called_once_with(["event1", "event2"])
153-
self.assertEqual(self.client._events_batch, [])
202+
client._send_telemetry.assert_called_once_with(["event1", "event2"])
203+
assert client._events_batch == []
154204

155205
@patch("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory")
156-
def test_close(self, mock_factory_class):
206+
def test_close(self, mock_factory_class, telemetry_client_setup):
157207
"""Test closing the client."""
158-
self.client.flush = MagicMock()
208+
client = telemetry_client_setup["client"]
209+
connection_uuid = telemetry_client_setup["connection_uuid"]
210+
client.flush = MagicMock()
159211

160-
self.client.close()
212+
client.close()
161213

162-
self.client.flush.assert_called_once()
163-
mock_factory_class.close.assert_called_once_with(self.connection_uuid)
214+
client.flush.assert_called_once()
215+
mock_factory_class.close.assert_called_once_with(connection_uuid)
164216

165217

166-
class TestTelemetryClientFactory(unittest.TestCase):
218+
class TestTelemetryClientFactory:
167219
"""Tests for the TelemetryClientFactory static class."""
168220

169-
def setUp(self):
170-
"""Set up test fixtures."""
171-
# Reset the static class state before each test
172-
TelemetryClientFactory._clients = {}
173-
TelemetryClientFactory._executor = None
174-
TelemetryClientFactory._initialized = False
175-
176221
@patch("databricks.sql.telemetry.telemetry_client.TelemetryClient")
177-
def test_initialize_telemetry_client_enabled(self, mock_client_class):
222+
def test_initialize_telemetry_client_enabled(self, mock_client_class, telemetry_factory_reset):
178223
"""Test initializing a telemetry client when telemetry is enabled."""
179224
connection_uuid = "test-uuid"
180225
auth_provider = MagicMock()
@@ -197,16 +242,16 @@ def test_initialize_telemetry_client_enabled(self, mock_client_class):
197242
host_url=host_url,
198243
executor=TelemetryClientFactory._executor,
199244
)
200-
self.assertEqual(TelemetryClientFactory._clients[connection_uuid], mock_client)
245+
assert TelemetryClientFactory._clients[connection_uuid] == mock_client
201246

202247
# Call again with the same connection_uuid
203248
client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid=connection_uuid)
204249

205250
# Verify the same client was returned and no new client was created
206-
self.assertEqual(client2, mock_client)
251+
assert client2 == mock_client
207252
mock_client_class.assert_called_once() # Still only called once
208253

209-
def test_initialize_telemetry_client_disabled(self):
254+
def test_initialize_telemetry_client_disabled(self, telemetry_factory_reset):
210255
"""Test initializing a telemetry client when telemetry is disabled."""
211256
connection_uuid = "test-uuid"
212257
TelemetryClientFactory.initialize_telemetry_client(
@@ -217,30 +262,30 @@ def test_initialize_telemetry_client_disabled(self):
217262
)
218263

219264
# Verify a NoopTelemetryClient was stored
220-
self.assertIsInstance(TelemetryClientFactory._clients[connection_uuid], NoopTelemetryClient)
265+
assert isinstance(TelemetryClientFactory._clients[connection_uuid], NoopTelemetryClient)
221266

222267
client2 = TelemetryClientFactory.get_telemetry_client(connection_uuid)
223-
self.assertIsInstance(client2, NoopTelemetryClient)
268+
assert isinstance(client2, NoopTelemetryClient)
224269

225-
def test_get_telemetry_client_existing(self):
270+
def test_get_telemetry_client_existing(self, telemetry_factory_reset):
226271
"""Test getting an existing telemetry client."""
227272
connection_uuid = "test-uuid"
228273
mock_client = MagicMock()
229274
TelemetryClientFactory._clients[connection_uuid] = mock_client
230275

231276
client = TelemetryClientFactory.get_telemetry_client(connection_uuid)
232277

233-
self.assertEqual(client, mock_client)
278+
assert client == mock_client
234279

235-
def test_get_telemetry_client_nonexistent(self):
280+
def test_get_telemetry_client_nonexistent(self, telemetry_factory_reset):
236281
"""Test getting a non-existent telemetry client."""
237282
client = TelemetryClientFactory.get_telemetry_client("nonexistent-uuid")
238283

239-
self.assertIsInstance(client, NoopTelemetryClient)
284+
assert isinstance(client, NoopTelemetryClient)
240285

241286
@patch("databricks.sql.telemetry.telemetry_client.ThreadPoolExecutor")
242287
@patch("databricks.sql.telemetry.telemetry_client.TelemetryClient")
243-
def test_close(self, mock_client_class, mock_executor_class):
288+
def test_close(self, mock_client_class, mock_executor_class, telemetry_factory_reset):
244289
"""Test that factory reinitializes properly after complete shutdown."""
245290
connection_uuid1 = "test-uuid1"
246291
mock_executor1 = MagicMock()
@@ -254,9 +299,9 @@ def test_close(self, mock_client_class, mock_executor_class):
254299

255300
TelemetryClientFactory.close(connection_uuid1)
256301

257-
self.assertEqual(TelemetryClientFactory._clients, {})
258-
self.assertIsNone(TelemetryClientFactory._executor)
259-
self.assertFalse(TelemetryClientFactory._initialized)
302+
assert TelemetryClientFactory._clients == {}
303+
assert TelemetryClientFactory._executor is None
304+
assert TelemetryClientFactory._initialized is False
260305
mock_executor1.shutdown.assert_called_once_with(wait=True)
261306

262307
# Now create a new client - this should reinitialize the factory
@@ -274,14 +319,11 @@ def test_close(self, mock_client_class, mock_executor_class):
274319
)
275320

276321
# Verify factory was reinitialized
277-
self.assertTrue(TelemetryClientFactory._initialized)
278-
self.assertIsNotNone(TelemetryClientFactory._executor)
279-
self.assertEqual(TelemetryClientFactory._executor, mock_executor2)
280-
self.assertIn(connection_uuid2, TelemetryClientFactory._clients)
281-
self.assertEqual(TelemetryClientFactory._clients[connection_uuid2], mock_client2)
322+
assert TelemetryClientFactory._initialized is True
323+
assert TelemetryClientFactory._executor is not None
324+
assert TelemetryClientFactory._executor == mock_executor2
325+
assert connection_uuid2 in TelemetryClientFactory._clients
326+
assert TelemetryClientFactory._clients[connection_uuid2] == mock_client2
282327

283328
# Verify new ThreadPoolExecutor was created
284-
self.assertEqual(mock_executor_class.call_count, 1)
285-
286-
if __name__ == "__main__":
287-
unittest.main()
329+
assert mock_executor_class.call_count == 1

0 commit comments

Comments
 (0)