Skip to content

Commit 44fade7

Browse files
committed
unit test for telemetry client
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 1c1c1be commit 44fade7

File tree

1 file changed

+270
-0
lines changed

1 file changed

+270
-0
lines changed

tests/unit/test_telemetry.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
import unittest
2+
import uuid
3+
import requests
4+
from unittest.mock import patch, MagicMock, call
5+
6+
from databricks.sql.telemetry.telemetry_client import (
7+
TelemetryClient,
8+
NoopTelemetryClient,
9+
TelemetryClientFactory,
10+
)
11+
from databricks.sql.telemetry.models.enums import (
12+
AuthMech,
13+
DatabricksClientType,
14+
)
15+
from databricks.sql.telemetry.models.event import (
16+
DriverConnectionParameters,
17+
HostDetails,
18+
)
19+
from databricks.sql.auth.authenticators import (
20+
AccessTokenAuthProvider,
21+
)
22+
23+
class TestNoopTelemetryClient(unittest.TestCase):
24+
"""Tests for the NoopTelemetryClient class."""
25+
26+
def test_singleton(self):
27+
"""Test that NoopTelemetryClient is a singleton."""
28+
client1 = NoopTelemetryClient()
29+
client2 = NoopTelemetryClient()
30+
self.assertIs(client1, client2)
31+
32+
def test_export_initial_telemetry_log(self):
33+
"""Test that export_initial_telemetry_log does nothing."""
34+
client = NoopTelemetryClient()
35+
client.export_initial_telemetry_log()
36+
37+
def test_close(self):
38+
"""Test that close does nothing."""
39+
client = NoopTelemetryClient()
40+
client.close()
41+
42+
43+
class TestTelemetryClient(unittest.TestCase):
44+
"""Tests for the TelemetryClient class."""
45+
46+
def setUp(self):
47+
"""Set up test fixtures."""
48+
self.connection_uuid = str(uuid.uuid4())
49+
self.auth_provider = AccessTokenAuthProvider("test-token")
50+
self.user_agent = "test-user-agent"
51+
self.host_url = "test-host"
52+
self.driver_connection_params = DriverConnectionParameters(
53+
http_path="test-path",
54+
mode=DatabricksClientType.THRIFT,
55+
host_info=HostDetails(host_url=self.host_url, port=443),
56+
auth_mech=AuthMech.PAT,
57+
auth_flow=None,
58+
)
59+
self.executor = MagicMock()
60+
61+
self.client = TelemetryClient(
62+
telemetry_enabled=True,
63+
batch_size=10,
64+
connection_uuid=self.connection_uuid,
65+
auth_provider=self.auth_provider,
66+
user_agent=self.user_agent,
67+
driver_connection_params=self.driver_connection_params,
68+
executor=self.executor,
69+
)
70+
71+
@patch("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog")
72+
@patch("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration")
73+
@patch("databricks.sql.telemetry.telemetry_client.uuid.uuid4")
74+
@patch("databricks.sql.telemetry.telemetry_client.time.time")
75+
def test_export_initial_telemetry_log(self, mock_time, mock_uuid4, mock_get_driver_config, mock_frontend_log):
76+
"""Test exporting initial telemetry log."""
77+
mock_time.return_value = 1000
78+
mock_uuid4.return_value = "test-uuid"
79+
mock_get_driver_config.return_value = "test-driver-config"
80+
mock_frontend_log.return_value = MagicMock()
81+
82+
self.client.export_event = MagicMock()
83+
84+
self.client.export_initial_telemetry_log()
85+
86+
mock_frontend_log.assert_called_once()
87+
self.client.export_event.assert_called_once_with(mock_frontend_log.return_value)
88+
89+
def test_export_event(self):
90+
"""Test exporting an event."""
91+
self.client.flush = MagicMock()
92+
93+
for i in range(5):
94+
self.client.export_event(f"event-{i}")
95+
96+
self.client.flush.assert_not_called()
97+
self.assertEqual(len(self.client.events_batch), 5)
98+
99+
for i in range(5, 10):
100+
self.client.export_event(f"event-{i}")
101+
102+
self.client.flush.assert_called_once()
103+
self.assertEqual(len(self.client.events_batch), 10)
104+
105+
@patch("requests.post")
106+
def test_send_telemetry_authenticated(self, mock_post):
107+
"""Test sending telemetry to the server with authentication."""
108+
events = [MagicMock(), MagicMock()]
109+
events[0].to_json.return_value = '{"event": "1"}'
110+
events[1].to_json.return_value = '{"event": "2"}'
111+
112+
self.client._send_telemetry(events)
113+
114+
self.executor.submit.assert_called_once()
115+
args, kwargs = self.executor.submit.call_args
116+
self.assertEqual(args[0], requests.post)
117+
self.assertEqual(kwargs["timeout"], 10)
118+
self.assertIn("Authorization", kwargs["headers"])
119+
self.assertEqual(kwargs["headers"]["Authorization"], "Bearer test-token")
120+
121+
@patch("requests.post")
122+
def test_send_telemetry_unauthenticated(self, mock_post):
123+
"""Test sending telemetry to the server without authentication."""
124+
unauthenticated_client = TelemetryClient(
125+
telemetry_enabled=True,
126+
batch_size=10,
127+
connection_uuid=str(uuid.uuid4()),
128+
auth_provider=None, # No auth provider
129+
user_agent=self.user_agent,
130+
driver_connection_params=self.driver_connection_params,
131+
executor=self.executor,
132+
)
133+
134+
events = [MagicMock(), MagicMock()]
135+
events[0].to_json.return_value = '{"event": "1"}'
136+
events[1].to_json.return_value = '{"event": "2"}'
137+
138+
unauthenticated_client._send_telemetry(events)
139+
140+
self.executor.submit.assert_called_once()
141+
args, kwargs = self.executor.submit.call_args
142+
self.assertEqual(args[0], requests.post)
143+
self.assertEqual(kwargs["timeout"], 10)
144+
self.assertNotIn("Authorization", kwargs["headers"]) # No auth header
145+
self.assertEqual(kwargs["headers"]["Accept"], "application/json")
146+
self.assertEqual(kwargs["headers"]["Content-Type"], "application/json")
147+
148+
def test_flush(self):
149+
"""Test flushing events."""
150+
self.client.events_batch = ["event1", "event2"]
151+
self.client._send_telemetry = MagicMock()
152+
153+
self.client.flush()
154+
155+
self.client._send_telemetry.assert_called_once_with(["event1", "event2"])
156+
self.assertEqual(self.client.events_batch, [])
157+
158+
@patch("databricks.sql.telemetry.telemetry_client.telemetry_client_factory")
159+
def test_close(self, mock_factory):
160+
"""Test closing the client."""
161+
self.client.flush = MagicMock()
162+
163+
self.client.close()
164+
165+
self.client.flush.assert_called_once()
166+
mock_factory.close.assert_called_once_with(self.connection_uuid)
167+
168+
169+
class TestTelemetryClientFactory(unittest.TestCase):
170+
"""Tests for the TelemetryClientFactory class."""
171+
172+
def setUp(self):
173+
"""Set up test fixtures."""
174+
TelemetryClientFactory._instance = None
175+
self.factory = TelemetryClientFactory()
176+
177+
def test_singleton(self):
178+
"""Test that TelemetryClientFactory is a singleton."""
179+
factory1 = TelemetryClientFactory()
180+
factory2 = TelemetryClientFactory()
181+
self.assertIs(factory1, factory2)
182+
183+
@patch("databricks.sql.telemetry.telemetry_client.TelemetryClient")
184+
def test_get_telemetry_client_enabled(self, mock_client_class):
185+
"""Test getting a telemetry client when telemetry is enabled."""
186+
connection_uuid = "test-uuid"
187+
auth_provider = MagicMock()
188+
user_agent = "test-user-agent"
189+
driver_connection_params = MagicMock()
190+
mock_client = MagicMock()
191+
mock_client_class.return_value = mock_client
192+
193+
client = self.factory.get_telemetry_client(
194+
telemetry_enabled=True,
195+
batch_size=10,
196+
connection_uuid=connection_uuid,
197+
auth_provider=auth_provider,
198+
user_agent=user_agent,
199+
driver_connection_params=driver_connection_params,
200+
)
201+
202+
# Verify a new client was created and stored
203+
mock_client_class.assert_called_once_with(
204+
telemetry_enabled=True,
205+
batch_size=10,
206+
connection_uuid=connection_uuid,
207+
auth_provider=auth_provider,
208+
user_agent=user_agent,
209+
driver_connection_params=driver_connection_params,
210+
executor=self.factory.executor,
211+
)
212+
self.assertEqual(client, mock_client)
213+
self.assertEqual(self.factory._clients[connection_uuid], mock_client)
214+
215+
# Call again with the same connection_uuid
216+
client2 = self.factory.get_telemetry_client(
217+
telemetry_enabled=True,
218+
batch_size=10,
219+
connection_uuid=connection_uuid,
220+
auth_provider=auth_provider,
221+
user_agent=user_agent,
222+
driver_connection_params=driver_connection_params,
223+
)
224+
225+
# Verify the same client was returned and no new client was created
226+
self.assertEqual(client2, mock_client)
227+
mock_client_class.assert_called_once() # Still only called once
228+
229+
def test_get_telemetry_client_disabled(self):
230+
"""Test getting a telemetry client when telemetry is disabled."""
231+
client = self.factory.get_telemetry_client(
232+
telemetry_enabled=False,
233+
batch_size=10,
234+
connection_uuid="test-uuid",
235+
auth_provider=MagicMock(),
236+
user_agent="test-user-agent",
237+
driver_connection_params=MagicMock(),
238+
)
239+
240+
# Verify a NoopTelemetryClient was returned
241+
self.assertIsInstance(client, NoopTelemetryClient)
242+
self.assertEqual(self.factory._clients, {}) # No client was stored
243+
244+
@patch("databricks.sql.telemetry.telemetry_client.ThreadPoolExecutor")
245+
def test_close(self, mock_executor_class):
246+
"""Test closing a client."""
247+
connection_uuid = "test-uuid"
248+
self.factory._clients[connection_uuid] = MagicMock()
249+
mock_executor = MagicMock()
250+
self.factory.executor = mock_executor
251+
252+
self.factory.close(connection_uuid)
253+
254+
# Verify the client was removed
255+
self.assertEqual(self.factory._clients, {})
256+
257+
# Verify the executor was shutdown
258+
mock_executor.shutdown.assert_called_once_with(wait=True)
259+
260+
# Add another client and close it
261+
connection_uuid2 = "test-uuid2"
262+
self.factory._clients[connection_uuid2] = MagicMock()
263+
self.factory.close(connection_uuid2)
264+
265+
# Verify the executor was shutdown again
266+
self.assertEqual(mock_executor.shutdown.call_count, 2)
267+
268+
269+
if __name__ == "__main__":
270+
unittest.main()

0 commit comments

Comments
 (0)