1- import unittest
21import uuid
2+ import pytest
33import requests
44from unittest .mock import patch , MagicMock , call
55
2020 AccessTokenAuthProvider ,
2121)
2222
23- class TestNoopTelemetryClient (unittest .TestCase ):
23+
24+ @pytest .fixture
25+ def noop_telemetry_client ():
26+ """Fixture for NoopTelemetryClient."""
27+ return NoopTelemetryClient ()
28+
29+
30+ @pytest .fixture
31+ def telemetry_client_setup ():
32+ """Fixture for TelemetryClient setup data."""
33+ connection_uuid = str (uuid .uuid4 ())
34+ auth_provider = AccessTokenAuthProvider ("test-token" )
35+ host_url = "test-host"
36+ executor = MagicMock ()
37+
38+ client = TelemetryClient (
39+ telemetry_enabled = True ,
40+ connection_uuid = connection_uuid ,
41+ auth_provider = auth_provider ,
42+ host_url = host_url ,
43+ executor = executor ,
44+ )
45+
46+ return {
47+ "client" : client ,
48+ "connection_uuid" : connection_uuid ,
49+ "auth_provider" : auth_provider ,
50+ "host_url" : host_url ,
51+ "executor" : executor ,
52+ }
53+
54+
55+ @pytest .fixture
56+ def telemetry_factory_reset ():
57+ """Fixture to reset TelemetryClientFactory state before each test."""
58+ # Reset the static class state before each test
59+ TelemetryClientFactory ._clients = {}
60+ TelemetryClientFactory ._executor = None
61+ TelemetryClientFactory ._initialized = False
62+ yield
63+ # Cleanup after test if needed
64+ TelemetryClientFactory ._clients = {}
65+ if TelemetryClientFactory ._executor :
66+ TelemetryClientFactory ._executor .shutdown (wait = True )
67+ TelemetryClientFactory ._executor = None
68+ TelemetryClientFactory ._initialized = False
69+
70+
71+ class TestNoopTelemetryClient :
2472 """Tests for the NoopTelemetryClient class."""
2573
2674 def test_singleton (self ):
2775 """Test that NoopTelemetryClient is a singleton."""
2876 client1 = NoopTelemetryClient ()
2977 client2 = NoopTelemetryClient ()
30- self . assertIs ( client1 , client2 )
78+ assert client1 is client2
3179
32- def test_export_initial_telemetry_log (self ):
80+ def test_export_initial_telemetry_log (self , noop_telemetry_client ):
3381 """Test that export_initial_telemetry_log does nothing."""
34- client = NoopTelemetryClient ()
35- client .export_initial_telemetry_log (driver_connection_params = MagicMock (), user_agent = "test" )
82+ noop_telemetry_client .export_initial_telemetry_log (
83+ driver_connection_params = MagicMock (), user_agent = "test"
84+ )
3685
37- def test_close (self ):
86+ def test_close (self , noop_telemetry_client ):
3887 """Test that close does nothing."""
39- client = NoopTelemetryClient ()
40- client .close ()
88+ noop_telemetry_client .close ()
4189
4290
43- class TestTelemetryClient ( unittest . TestCase ) :
91+ class TestTelemetryClient :
4492 """Tests for the TelemetryClient class."""
4593
46- def setUp (self ):
47- """Set up test fixtures."""
48- self .connection_uuid = str (uuid .uuid4 ())
49- self .auth_provider = AccessTokenAuthProvider ("test-token" )
50- self .host_url = "test-host"
51- self .executor = MagicMock ()
52-
53- self .client = TelemetryClient (
54- telemetry_enabled = True ,
55- connection_uuid = self .connection_uuid ,
56- auth_provider = self .auth_provider ,
57- host_url = self .host_url ,
58- executor = self .executor ,
59- )
60-
6194 @patch ("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog" )
6295 @patch ("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration" )
6396 @patch ("databricks.sql.telemetry.telemetry_client.uuid.uuid4" )
6497 @patch ("databricks.sql.telemetry.telemetry_client.time.time" )
65- def test_export_initial_telemetry_log (self , mock_time , mock_uuid4 , mock_get_driver_config , mock_frontend_log ):
98+ def test_export_initial_telemetry_log (
99+ self ,
100+ mock_time ,
101+ mock_uuid4 ,
102+ mock_get_driver_config ,
103+ mock_frontend_log ,
104+ telemetry_client_setup
105+ ):
66106 """Test exporting initial telemetry log."""
67107 mock_time .return_value = 1000
68108 mock_uuid4 .return_value = "test-uuid"
69109 mock_get_driver_config .return_value = "test-driver-config"
70110 mock_frontend_log .return_value = MagicMock ()
71-
72- self .client .export_event = MagicMock ()
111+
112+ client = telemetry_client_setup ["client" ]
113+ host_url = telemetry_client_setup ["host_url" ]
114+ client .export_event = MagicMock ()
73115
74116 driver_connection_params = DriverConnectionParameters (
75117 http_path = "test-path" ,
76118 mode = DatabricksClientType .THRIFT ,
77- host_info = HostDetails (host_url = self . host_url , port = 443 ),
119+ host_info = HostDetails (host_url = host_url , port = 443 ),
78120 auth_mech = AuthMech .PAT ,
79121 auth_flow = None ,
80122 )
81123 user_agent = "test-user-agent"
82124
83- self . client .export_initial_telemetry_log (driver_connection_params , user_agent )
125+ client .export_initial_telemetry_log (driver_connection_params , user_agent )
84126
85127 mock_frontend_log .assert_called_once ()
86- self . client .export_event .assert_called_once_with (mock_frontend_log .return_value )
128+ client .export_event .assert_called_once_with (mock_frontend_log .return_value )
87129
88- def test_export_event (self ):
130+ def test_export_event (self , telemetry_client_setup ):
89131 """Test exporting an event."""
90- self .client .flush = MagicMock ()
132+ client = telemetry_client_setup ["client" ]
133+ client .flush = MagicMock ()
91134
92135 for i in range (5 ):
93- self . client .export_event (f"event-{ i } " )
136+ client .export_event (f"event-{ i } " )
94137
95- self . client .flush .assert_not_called ()
96- self . assertEqual ( len (self . client ._events_batch ), 5 )
138+ client .flush .assert_not_called ()
139+ assert len (client ._events_batch ) == 5
97140
98141 for i in range (5 , 10 ):
99- self . client .export_event (f"event-{ i } " )
142+ client .export_event (f"event-{ i } " )
100143
101- self . client .flush .assert_called_once ()
102- self . assertEqual ( len (self . client ._events_batch ), 10 )
144+ client .flush .assert_called_once ()
145+ assert len (client ._events_batch ) == 10
103146
104147 @patch ("requests.post" )
105- def test_send_telemetry_authenticated (self , mock_post ):
148+ def test_send_telemetry_authenticated (self , mock_post , telemetry_client_setup ):
106149 """Test sending telemetry to the server with authentication."""
150+ client = telemetry_client_setup ["client" ]
151+ executor = telemetry_client_setup ["executor" ]
152+
107153 events = [MagicMock (), MagicMock ()]
108154 events [0 ].to_json .return_value = '{"event": "1"}'
109155 events [1 ].to_json .return_value = '{"event": "2"}'
110156
111- self . client ._send_telemetry (events )
157+ client ._send_telemetry (events )
112158
113- self . executor .submit .assert_called_once ()
114- args , kwargs = self . executor .submit .call_args
115- self . assertEqual ( args [0 ], requests .post )
116- self . assertEqual ( kwargs ["timeout" ], 10 )
117- self . assertIn ( "Authorization" , kwargs ["headers" ])
118- self . assertEqual ( kwargs ["headers" ]["Authorization" ], "Bearer test-token" )
159+ executor .submit .assert_called_once ()
160+ args , kwargs = executor .submit .call_args
161+ assert args [0 ] == requests .post
162+ assert kwargs ["timeout" ] == 10
163+ assert "Authorization" in kwargs ["headers" ]
164+ assert kwargs ["headers" ]["Authorization" ] == "Bearer test-token"
119165
120166 @patch ("requests.post" )
121- def test_send_telemetry_unauthenticated (self , mock_post ):
167+ def test_send_telemetry_unauthenticated (self , mock_post , telemetry_client_setup ):
122168 """Test sending telemetry to the server without authentication."""
169+ host_url = telemetry_client_setup ["host_url" ]
170+ executor = telemetry_client_setup ["executor" ]
171+
123172 unauthenticated_client = TelemetryClient (
124173 telemetry_enabled = True ,
125174 connection_uuid = str (uuid .uuid4 ()),
126175 auth_provider = None , # No auth provider
127- host_url = self . host_url ,
128- executor = self . executor ,
176+ host_url = host_url ,
177+ executor = executor ,
129178 )
130179
131180 events = [MagicMock (), MagicMock ()]
@@ -134,47 +183,43 @@ def test_send_telemetry_unauthenticated(self, mock_post):
134183
135184 unauthenticated_client ._send_telemetry (events )
136185
137- self . executor .submit .assert_called_once ()
138- args , kwargs = self . executor .submit .call_args
139- self . assertEqual ( args [0 ], requests .post )
140- self . assertEqual ( kwargs ["timeout" ], 10 )
141- self . assertNotIn ( "Authorization" , kwargs ["headers" ]) # No auth header
142- self . assertEqual ( kwargs ["headers" ]["Accept" ], "application/json" )
143- self . assertEqual ( kwargs ["headers" ]["Content-Type" ], "application/json" )
186+ executor .submit .assert_called_once ()
187+ args , kwargs = executor .submit .call_args
188+ assert args [0 ] == requests .post
189+ assert kwargs ["timeout" ] == 10
190+ assert "Authorization" not in kwargs ["headers" ] # No auth header
191+ assert kwargs ["headers" ]["Accept" ] == "application/json"
192+ assert kwargs ["headers" ]["Content-Type" ] == "application/json"
144193
145- def test_flush (self ):
194+ def test_flush (self , telemetry_client_setup ):
146195 """Test flushing events."""
147- self .client ._events_batch = ["event1" , "event2" ]
148- self .client ._send_telemetry = MagicMock ()
196+ client = telemetry_client_setup ["client" ]
197+ client ._events_batch = ["event1" , "event2" ]
198+ client ._send_telemetry = MagicMock ()
149199
150- self . client .flush ()
200+ client .flush ()
151201
152- self . client ._send_telemetry .assert_called_once_with (["event1" , "event2" ])
153- self . assertEqual ( self . client ._events_batch , [])
202+ client ._send_telemetry .assert_called_once_with (["event1" , "event2" ])
203+ assert client ._events_batch == []
154204
155205 @patch ("databricks.sql.telemetry.telemetry_client.TelemetryClientFactory" )
156- def test_close (self , mock_factory_class ):
206+ def test_close (self , mock_factory_class , telemetry_client_setup ):
157207 """Test closing the client."""
158- self .client .flush = MagicMock ()
208+ client = telemetry_client_setup ["client" ]
209+ connection_uuid = telemetry_client_setup ["connection_uuid" ]
210+ client .flush = MagicMock ()
159211
160- self . client .close ()
212+ client .close ()
161213
162- self . client .flush .assert_called_once ()
163- mock_factory_class .close .assert_called_once_with (self . connection_uuid )
214+ client .flush .assert_called_once ()
215+ mock_factory_class .close .assert_called_once_with (connection_uuid )
164216
165217
166- class TestTelemetryClientFactory ( unittest . TestCase ) :
218+ class TestTelemetryClientFactory :
167219 """Tests for the TelemetryClientFactory static class."""
168220
169- def setUp (self ):
170- """Set up test fixtures."""
171- # Reset the static class state before each test
172- TelemetryClientFactory ._clients = {}
173- TelemetryClientFactory ._executor = None
174- TelemetryClientFactory ._initialized = False
175-
176221 @patch ("databricks.sql.telemetry.telemetry_client.TelemetryClient" )
177- def test_initialize_telemetry_client_enabled (self , mock_client_class ):
222+ def test_initialize_telemetry_client_enabled (self , mock_client_class , telemetry_factory_reset ):
178223 """Test initializing a telemetry client when telemetry is enabled."""
179224 connection_uuid = "test-uuid"
180225 auth_provider = MagicMock ()
@@ -197,16 +242,16 @@ def test_initialize_telemetry_client_enabled(self, mock_client_class):
197242 host_url = host_url ,
198243 executor = TelemetryClientFactory ._executor ,
199244 )
200- self . assertEqual ( TelemetryClientFactory ._clients [connection_uuid ], mock_client )
245+ assert TelemetryClientFactory ._clients [connection_uuid ] == mock_client
201246
202247 # Call again with the same connection_uuid
203248 client2 = TelemetryClientFactory .get_telemetry_client (connection_uuid = connection_uuid )
204249
205250 # Verify the same client was returned and no new client was created
206- self . assertEqual ( client2 , mock_client )
251+ assert client2 == mock_client
207252 mock_client_class .assert_called_once () # Still only called once
208253
209- def test_initialize_telemetry_client_disabled (self ):
254+ def test_initialize_telemetry_client_disabled (self , telemetry_factory_reset ):
210255 """Test initializing a telemetry client when telemetry is disabled."""
211256 connection_uuid = "test-uuid"
212257 TelemetryClientFactory .initialize_telemetry_client (
@@ -217,30 +262,30 @@ def test_initialize_telemetry_client_disabled(self):
217262 )
218263
219264 # Verify a NoopTelemetryClient was stored
220- self . assertIsInstance (TelemetryClientFactory ._clients [connection_uuid ], NoopTelemetryClient )
265+ assert isinstance (TelemetryClientFactory ._clients [connection_uuid ], NoopTelemetryClient )
221266
222267 client2 = TelemetryClientFactory .get_telemetry_client (connection_uuid )
223- self . assertIsInstance (client2 , NoopTelemetryClient )
268+ assert isinstance (client2 , NoopTelemetryClient )
224269
225- def test_get_telemetry_client_existing (self ):
270+ def test_get_telemetry_client_existing (self , telemetry_factory_reset ):
226271 """Test getting an existing telemetry client."""
227272 connection_uuid = "test-uuid"
228273 mock_client = MagicMock ()
229274 TelemetryClientFactory ._clients [connection_uuid ] = mock_client
230275
231276 client = TelemetryClientFactory .get_telemetry_client (connection_uuid )
232277
233- self . assertEqual ( client , mock_client )
278+ assert client == mock_client
234279
235- def test_get_telemetry_client_nonexistent (self ):
280+ def test_get_telemetry_client_nonexistent (self , telemetry_factory_reset ):
236281 """Test getting a non-existent telemetry client."""
237282 client = TelemetryClientFactory .get_telemetry_client ("nonexistent-uuid" )
238283
239- self . assertIsInstance (client , NoopTelemetryClient )
284+ assert isinstance (client , NoopTelemetryClient )
240285
241286 @patch ("databricks.sql.telemetry.telemetry_client.ThreadPoolExecutor" )
242287 @patch ("databricks.sql.telemetry.telemetry_client.TelemetryClient" )
243- def test_close (self , mock_client_class , mock_executor_class ):
288+ def test_close (self , mock_client_class , mock_executor_class , telemetry_factory_reset ):
244289 """Test that factory reinitializes properly after complete shutdown."""
245290 connection_uuid1 = "test-uuid1"
246291 mock_executor1 = MagicMock ()
@@ -254,9 +299,9 @@ def test_close(self, mock_client_class, mock_executor_class):
254299
255300 TelemetryClientFactory .close (connection_uuid1 )
256301
257- self . assertEqual ( TelemetryClientFactory ._clients , {})
258- self . assertIsNone ( TelemetryClientFactory ._executor )
259- self . assertFalse ( TelemetryClientFactory ._initialized )
302+ assert TelemetryClientFactory ._clients == {}
303+ assert TelemetryClientFactory ._executor is None
304+ assert TelemetryClientFactory ._initialized is False
260305 mock_executor1 .shutdown .assert_called_once_with (wait = True )
261306
262307 # Now create a new client - this should reinitialize the factory
@@ -274,14 +319,11 @@ def test_close(self, mock_client_class, mock_executor_class):
274319 )
275320
276321 # Verify factory was reinitialized
277- self . assertTrue ( TelemetryClientFactory ._initialized )
278- self . assertIsNotNone ( TelemetryClientFactory ._executor )
279- self . assertEqual ( TelemetryClientFactory ._executor , mock_executor2 )
280- self . assertIn ( connection_uuid2 , TelemetryClientFactory ._clients )
281- self . assertEqual ( TelemetryClientFactory ._clients [connection_uuid2 ], mock_client2 )
322+ assert TelemetryClientFactory ._initialized is True
323+ assert TelemetryClientFactory ._executor is not None
324+ assert TelemetryClientFactory ._executor == mock_executor2
325+ assert connection_uuid2 in TelemetryClientFactory ._clients
326+ assert TelemetryClientFactory ._clients [connection_uuid2 ] == mock_client2
282327
283328 # Verify new ThreadPoolExecutor was created
284- self .assertEqual (mock_executor_class .call_count , 1 )
285-
286- if __name__ == "__main__" :
287- unittest .main ()
329+ assert mock_executor_class .call_count == 1
0 commit comments