1+ import unittest
2+ import uuid
3+ import requests
4+ from unittest .mock import patch , MagicMock , call
5+
6+ from databricks .sql .telemetry .telemetry_client import (
7+ TelemetryClient ,
8+ NoopTelemetryClient ,
9+ TelemetryClientFactory ,
10+ )
11+ from databricks .sql .telemetry .models .enums import (
12+ AuthMech ,
13+ DatabricksClientType ,
14+ )
15+ from databricks .sql .telemetry .models .event import (
16+ DriverConnectionParameters ,
17+ HostDetails ,
18+ )
19+ from databricks .sql .auth .authenticators import (
20+ AccessTokenAuthProvider ,
21+ )
22+
23+ class TestNoopTelemetryClient (unittest .TestCase ):
24+ """Tests for the NoopTelemetryClient class."""
25+
26+ def test_singleton (self ):
27+ """Test that NoopTelemetryClient is a singleton."""
28+ client1 = NoopTelemetryClient ()
29+ client2 = NoopTelemetryClient ()
30+ self .assertIs (client1 , client2 )
31+
32+ def test_export_initial_telemetry_log (self ):
33+ """Test that export_initial_telemetry_log does nothing."""
34+ client = NoopTelemetryClient ()
35+ client .export_initial_telemetry_log ()
36+
37+ def test_close (self ):
38+ """Test that close does nothing."""
39+ client = NoopTelemetryClient ()
40+ client .close ()
41+
42+
43+ class TestTelemetryClient (unittest .TestCase ):
44+ """Tests for the TelemetryClient class."""
45+
46+ def setUp (self ):
47+ """Set up test fixtures."""
48+ self .connection_uuid = str (uuid .uuid4 ())
49+ self .auth_provider = AccessTokenAuthProvider ("test-token" )
50+ self .user_agent = "test-user-agent"
51+ self .host_url = "test-host"
52+ self .driver_connection_params = DriverConnectionParameters (
53+ http_path = "test-path" ,
54+ mode = DatabricksClientType .THRIFT ,
55+ host_info = HostDetails (host_url = self .host_url , port = 443 ),
56+ auth_mech = AuthMech .PAT ,
57+ auth_flow = None ,
58+ )
59+ self .executor = MagicMock ()
60+
61+ self .client = TelemetryClient (
62+ telemetry_enabled = True ,
63+ batch_size = 10 ,
64+ connection_uuid = self .connection_uuid ,
65+ auth_provider = self .auth_provider ,
66+ user_agent = self .user_agent ,
67+ driver_connection_params = self .driver_connection_params ,
68+ executor = self .executor ,
69+ )
70+
71+ @patch ("databricks.sql.telemetry.telemetry_client.TelemetryFrontendLog" )
72+ @patch ("databricks.sql.telemetry.telemetry_client.TelemetryHelper.getDriverSystemConfiguration" )
73+ @patch ("databricks.sql.telemetry.telemetry_client.uuid.uuid4" )
74+ @patch ("databricks.sql.telemetry.telemetry_client.time.time" )
75+ def test_export_initial_telemetry_log (self , mock_time , mock_uuid4 , mock_get_driver_config , mock_frontend_log ):
76+ """Test exporting initial telemetry log."""
77+ mock_time .return_value = 1000
78+ mock_uuid4 .return_value = "test-uuid"
79+ mock_get_driver_config .return_value = "test-driver-config"
80+ mock_frontend_log .return_value = MagicMock ()
81+
82+ self .client .export_event = MagicMock ()
83+
84+ self .client .export_initial_telemetry_log ()
85+
86+ mock_frontend_log .assert_called_once ()
87+ self .client .export_event .assert_called_once_with (mock_frontend_log .return_value )
88+
89+ def test_export_event (self ):
90+ """Test exporting an event."""
91+ self .client .flush = MagicMock ()
92+
93+ for i in range (5 ):
94+ self .client .export_event (f"event-{ i } " )
95+
96+ self .client .flush .assert_not_called ()
97+ self .assertEqual (len (self .client .events_batch ), 5 )
98+
99+ for i in range (5 , 10 ):
100+ self .client .export_event (f"event-{ i } " )
101+
102+ self .client .flush .assert_called_once ()
103+ self .assertEqual (len (self .client .events_batch ), 10 )
104+
105+ @patch ("requests.post" )
106+ def test_send_telemetry_authenticated (self , mock_post ):
107+ """Test sending telemetry to the server with authentication."""
108+ events = [MagicMock (), MagicMock ()]
109+ events [0 ].to_json .return_value = '{"event": "1"}'
110+ events [1 ].to_json .return_value = '{"event": "2"}'
111+
112+ self .client ._send_telemetry (events )
113+
114+ self .executor .submit .assert_called_once ()
115+ args , kwargs = self .executor .submit .call_args
116+ self .assertEqual (args [0 ], requests .post )
117+ self .assertEqual (kwargs ["timeout" ], 10 )
118+ self .assertIn ("Authorization" , kwargs ["headers" ])
119+ self .assertEqual (kwargs ["headers" ]["Authorization" ], "Bearer test-token" )
120+
121+ @patch ("requests.post" )
122+ def test_send_telemetry_unauthenticated (self , mock_post ):
123+ """Test sending telemetry to the server without authentication."""
124+ unauthenticated_client = TelemetryClient (
125+ telemetry_enabled = True ,
126+ batch_size = 10 ,
127+ connection_uuid = str (uuid .uuid4 ()),
128+ auth_provider = None , # No auth provider
129+ user_agent = self .user_agent ,
130+ driver_connection_params = self .driver_connection_params ,
131+ executor = self .executor ,
132+ )
133+
134+ events = [MagicMock (), MagicMock ()]
135+ events [0 ].to_json .return_value = '{"event": "1"}'
136+ events [1 ].to_json .return_value = '{"event": "2"}'
137+
138+ unauthenticated_client ._send_telemetry (events )
139+
140+ self .executor .submit .assert_called_once ()
141+ args , kwargs = self .executor .submit .call_args
142+ self .assertEqual (args [0 ], requests .post )
143+ self .assertEqual (kwargs ["timeout" ], 10 )
144+ self .assertNotIn ("Authorization" , kwargs ["headers" ]) # No auth header
145+ self .assertEqual (kwargs ["headers" ]["Accept" ], "application/json" )
146+ self .assertEqual (kwargs ["headers" ]["Content-Type" ], "application/json" )
147+
148+ def test_flush (self ):
149+ """Test flushing events."""
150+ self .client .events_batch = ["event1" , "event2" ]
151+ self .client ._send_telemetry = MagicMock ()
152+
153+ self .client .flush ()
154+
155+ self .client ._send_telemetry .assert_called_once_with (["event1" , "event2" ])
156+ self .assertEqual (self .client .events_batch , [])
157+
158+ @patch ("databricks.sql.telemetry.telemetry_client.telemetry_client_factory" )
159+ def test_close (self , mock_factory ):
160+ """Test closing the client."""
161+ self .client .flush = MagicMock ()
162+
163+ self .client .close ()
164+
165+ self .client .flush .assert_called_once ()
166+ mock_factory .close .assert_called_once_with (self .connection_uuid )
167+
168+
169+ class TestTelemetryClientFactory (unittest .TestCase ):
170+ """Tests for the TelemetryClientFactory class."""
171+
172+ def setUp (self ):
173+ """Set up test fixtures."""
174+ TelemetryClientFactory ._instance = None
175+ self .factory = TelemetryClientFactory ()
176+
177+ def test_singleton (self ):
178+ """Test that TelemetryClientFactory is a singleton."""
179+ factory1 = TelemetryClientFactory ()
180+ factory2 = TelemetryClientFactory ()
181+ self .assertIs (factory1 , factory2 )
182+
183+ @patch ("databricks.sql.telemetry.telemetry_client.TelemetryClient" )
184+ def test_get_telemetry_client_enabled (self , mock_client_class ):
185+ """Test getting a telemetry client when telemetry is enabled."""
186+ connection_uuid = "test-uuid"
187+ auth_provider = MagicMock ()
188+ user_agent = "test-user-agent"
189+ driver_connection_params = MagicMock ()
190+ mock_client = MagicMock ()
191+ mock_client_class .return_value = mock_client
192+
193+ client = self .factory .get_telemetry_client (
194+ telemetry_enabled = True ,
195+ batch_size = 10 ,
196+ connection_uuid = connection_uuid ,
197+ auth_provider = auth_provider ,
198+ user_agent = user_agent ,
199+ driver_connection_params = driver_connection_params ,
200+ )
201+
202+ # Verify a new client was created and stored
203+ mock_client_class .assert_called_once_with (
204+ telemetry_enabled = True ,
205+ batch_size = 10 ,
206+ connection_uuid = connection_uuid ,
207+ auth_provider = auth_provider ,
208+ user_agent = user_agent ,
209+ driver_connection_params = driver_connection_params ,
210+ executor = self .factory .executor ,
211+ )
212+ self .assertEqual (client , mock_client )
213+ self .assertEqual (self .factory ._clients [connection_uuid ], mock_client )
214+
215+ # Call again with the same connection_uuid
216+ client2 = self .factory .get_telemetry_client (
217+ telemetry_enabled = True ,
218+ batch_size = 10 ,
219+ connection_uuid = connection_uuid ,
220+ auth_provider = auth_provider ,
221+ user_agent = user_agent ,
222+ driver_connection_params = driver_connection_params ,
223+ )
224+
225+ # Verify the same client was returned and no new client was created
226+ self .assertEqual (client2 , mock_client )
227+ mock_client_class .assert_called_once () # Still only called once
228+
229+ def test_get_telemetry_client_disabled (self ):
230+ """Test getting a telemetry client when telemetry is disabled."""
231+ client = self .factory .get_telemetry_client (
232+ telemetry_enabled = False ,
233+ batch_size = 10 ,
234+ connection_uuid = "test-uuid" ,
235+ auth_provider = MagicMock (),
236+ user_agent = "test-user-agent" ,
237+ driver_connection_params = MagicMock (),
238+ )
239+
240+ # Verify a NoopTelemetryClient was returned
241+ self .assertIsInstance (client , NoopTelemetryClient )
242+ self .assertEqual (self .factory ._clients , {}) # No client was stored
243+
244+ @patch ("databricks.sql.telemetry.telemetry_client.ThreadPoolExecutor" )
245+ def test_close (self , mock_executor_class ):
246+ """Test closing a client."""
247+ connection_uuid = "test-uuid"
248+ self .factory ._clients [connection_uuid ] = MagicMock ()
249+ mock_executor = MagicMock ()
250+ self .factory .executor = mock_executor
251+
252+ self .factory .close (connection_uuid )
253+
254+ # Verify the client was removed
255+ self .assertEqual (self .factory ._clients , {})
256+
257+ # Verify the executor was shutdown
258+ mock_executor .shutdown .assert_called_once_with (wait = True )
259+
260+ # Add another client and close it
261+ connection_uuid2 = "test-uuid2"
262+ self .factory ._clients [connection_uuid2 ] = MagicMock ()
263+ self .factory .close (connection_uuid2 )
264+
265+ # Verify the executor was shutdown again
266+ self .assertEqual (mock_executor .shutdown .call_count , 2 )
267+
268+
269+ if __name__ == "__main__" :
270+ unittest .main ()
0 commit comments