Skip to content

Commit 8a871e3

Browse files
committed
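Split get_telemetry_client into initialize_telemetry_client (creates and registers a client) and get_telemetry_client (looks up an existing client by connection_uuid); removed the batch_size parameter, renamed AuthMech enum members, and updated some comments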
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent 44fade7 commit 8a871e3

File tree

5 files changed

+38
-47
lines changed

5 files changed

+38
-47
lines changed

src/databricks/sql/client.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,6 @@ def read(self) -> Optional[OAuthToken]:
247247
self.telemetry_enabled = (
248248
self.client_telemetry_enabled and self.server_telemetry_enabled
249249
)
250-
telemetry_batch_size = kwargs.get(
251-
"telemetry_batch_size", 100
252-
) # TODO: Decide on batch size
253250

254251
user_agent_entry = kwargs.get("user_agent_entry")
255252
if user_agent_entry is None:
@@ -315,9 +312,8 @@ def read(self) -> Optional[OAuthToken]:
315312
discovery_url=TelemetryHelper.get_discovery_url(auth_provider),
316313
socket_timeout=kwargs.get("_socket_timeout", None),
317314
)
318-
self.telemetry_client = telemetry_client_factory.get_telemetry_client(
315+
self.telemetry_client = telemetry_client_factory.initialize_telemetry_client(
319316
telemetry_enabled=self.telemetry_enabled,
320-
batch_size=telemetry_batch_size,
321317
connection_uuid=self.get_session_id_hex(),
322318
auth_provider=auth_provider,
323319
user_agent=useragent_header,

src/databricks/sql/telemetry/models/enums.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ class AuthFlow(Enum):
99
class AuthMech(Enum):
1010
OTHER = "OTHER"
1111
PAT = "PAT"
12-
OAUTH = "OAUTH"
13-
EXTERNAL = "EXTERNAL"
12+
DATABRICKS_OAUTH = "DATABRICKS_OAUTH"
13+
EXTERNAL_AUTH = "EXTERNAL_AUTH"
1414

1515

1616
class DatabricksClientType(Enum):

src/databricks/sql/telemetry/models/event.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class EnumEncoder(json.JSONEncoder):
1616
"""
1717
Custom JSON encoder to handle Enum values.
1818
This is used to convert Enum values to their string representations.
19-
Default JSON encoder does not handle Enum values.
19+
Default JSON encoder raises a TypeError for Enums.
2020
"""
2121

2222
def default(self, obj):

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,14 @@ def get_auth_mechanism(auth_provider):
5858
if not auth_provider:
5959
return None
6060
if isinstance(auth_provider, AccessTokenAuthProvider):
61-
return AuthMech.PAT
61+
return AuthMech.PAT # Personal Access Token authentication
6262
elif isinstance(auth_provider, DatabricksOAuthProvider):
63-
return AuthMech.OAUTH
63+
return AuthMech.OAUTH # Databricks-managed OAuth flow
6464
elif isinstance(auth_provider, ExternalAuthProvider):
65-
return AuthMech.EXTERNAL
66-
return AuthMech.OTHER
65+
return (
66+
AuthMech.EXTERNAL
67+
) # External identity provider (AWS IAM, Azure AD, etc.)
68+
return AuthMech.OTHER # Custom or unknown authentication provider
6769

6870
@staticmethod
6971
def get_auth_flow(auth_provider):
@@ -72,14 +74,15 @@ def get_auth_flow(auth_provider):
7274
return None
7375

7476
if isinstance(auth_provider, DatabricksOAuthProvider):
75-
if (
76-
hasattr(auth_provider, "_refresh_token")
77-
and auth_provider._refresh_token
78-
):
79-
return AuthFlow.TOKEN_PASSTHROUGH
77+
if auth_provider._access_token and auth_provider._refresh_token:
78+
return (
79+
AuthFlow.TOKEN_PASSTHROUGH
80+
) # Has existing tokens, no user interaction needed
8081

8182
if hasattr(auth_provider, "oauth_manager"):
82-
return AuthFlow.BROWSER_BASED_AUTHENTICATION
83+
return (
84+
AuthFlow.BROWSER_BASED_AUTHENTICATION
85+
) # Will initiate OAuth flow requiring browser
8386

8487
return None
8588

@@ -140,15 +143,14 @@ class TelemetryClient(BaseTelemetryClient):
140143
def __init__(
141144
self,
142145
telemetry_enabled,
143-
batch_size,
144146
connection_uuid,
145147
auth_provider,
146148
user_agent,
147149
driver_connection_params,
148150
executor,
149151
):
150152
self.telemetry_enabled = telemetry_enabled
151-
self.batch_size = batch_size
153+
self.batch_size = 10 # TODO: Decide on batch size
152154
self.connection_uuid = connection_uuid
153155
self.auth_provider = auth_provider
154156
self.user_agent = user_agent
@@ -176,15 +178,12 @@ def flush(self):
176178

177179
def _send_telemetry(self, events):
178180
"""Send telemetry events to the server"""
179-
try:
180-
request = {
181-
"uploadTime": int(time.time() * 1000),
182-
"items": [],
183-
"protoLogs": [event.to_json() for event in events],
184-
}
185-
except Exception as e:
186-
print(f"[DEBUG] Error creating telemetry request: {e}", flush=True)
187-
raise e
181+
182+
request = {
183+
"uploadTime": int(time.time() * 1000),
184+
"items": [],
185+
"protoLogs": [event.to_json() for event in events],
186+
}
188187

189188
path = "/telemetry-ext" if self.auth_provider else "/telemetry-unauth"
190189
url = f"https://{self.host_url}{path}"
@@ -248,10 +247,9 @@ def __init__(self):
248247
) # Thread pool for async operations TODO: Decide on max workers
249248
self._initialized = True
250249

251-
def get_telemetry_client(
250+
def initialize_telemetry_client(
252251
self,
253252
telemetry_enabled,
254-
batch_size,
255253
connection_uuid,
256254
auth_provider,
257255
user_agent,
@@ -262,7 +260,6 @@ def get_telemetry_client(
262260
if connection_uuid not in self._clients:
263261
self._clients[connection_uuid] = TelemetryClient(
264262
telemetry_enabled=telemetry_enabled,
265-
batch_size=batch_size,
266263
connection_uuid=connection_uuid,
267264
auth_provider=auth_provider,
268265
user_agent=user_agent,
@@ -273,6 +270,13 @@ def get_telemetry_client(
273270
else:
274271
return NoopTelemetryClient()
275272

273+
def get_telemetry_client(self, connection_uuid):
274+
"""Get the telemetry client for a specific connection"""
275+
if connection_uuid in self._clients:
276+
return self._clients[connection_uuid]
277+
else:
278+
return NoopTelemetryClient()
279+
276280
def close(self, connection_uuid):
277281
if connection_uuid in self._clients:
278282
del self._clients[connection_uuid]

tests/unit/test_telemetry.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def setUp(self):
6060

6161
self.client = TelemetryClient(
6262
telemetry_enabled=True,
63-
batch_size=10,
6463
connection_uuid=self.connection_uuid,
6564
auth_provider=self.auth_provider,
6665
user_agent=self.user_agent,
@@ -123,7 +122,6 @@ def test_send_telemetry_unauthenticated(self, mock_post):
123122
"""Test sending telemetry to the server without authentication."""
124123
unauthenticated_client = TelemetryClient(
125124
telemetry_enabled=True,
126-
batch_size=10,
127125
connection_uuid=str(uuid.uuid4()),
128126
auth_provider=None, # No auth provider
129127
user_agent=self.user_agent,
@@ -190,9 +188,8 @@ def test_get_telemetry_client_enabled(self, mock_client_class):
190188
mock_client = MagicMock()
191189
mock_client_class.return_value = mock_client
192190

193-
client = self.factory.get_telemetry_client(
191+
client = self.factory.initialize_telemetry_client(
194192
telemetry_enabled=True,
195-
batch_size=10,
196193
connection_uuid=connection_uuid,
197194
auth_provider=auth_provider,
198195
user_agent=user_agent,
@@ -202,7 +199,6 @@ def test_get_telemetry_client_enabled(self, mock_client_class):
202199
# Verify a new client was created and stored
203200
mock_client_class.assert_called_once_with(
204201
telemetry_enabled=True,
205-
batch_size=10,
206202
connection_uuid=connection_uuid,
207203
auth_provider=auth_provider,
208204
user_agent=user_agent,
@@ -213,24 +209,16 @@ def test_get_telemetry_client_enabled(self, mock_client_class):
213209
self.assertEqual(self.factory._clients[connection_uuid], mock_client)
214210

215211
# Call again with the same connection_uuid
216-
client2 = self.factory.get_telemetry_client(
217-
telemetry_enabled=True,
218-
batch_size=10,
219-
connection_uuid=connection_uuid,
220-
auth_provider=auth_provider,
221-
user_agent=user_agent,
222-
driver_connection_params=driver_connection_params,
223-
)
212+
client2 = self.factory.get_telemetry_client(connection_uuid=connection_uuid)
224213

225214
# Verify the same client was returned and no new client was created
226215
self.assertEqual(client2, mock_client)
227216
mock_client_class.assert_called_once() # Still only called once
228217

229218
def test_get_telemetry_client_disabled(self):
230219
"""Test getting a telemetry client when telemetry is disabled."""
231-
client = self.factory.get_telemetry_client(
220+
client = self.factory.initialize_telemetry_client(
232221
telemetry_enabled=False,
233-
batch_size=10,
234222
connection_uuid="test-uuid",
235223
auth_provider=MagicMock(),
236224
user_agent="test-user-agent",
@@ -241,6 +229,9 @@ def test_get_telemetry_client_disabled(self):
241229
self.assertIsInstance(client, NoopTelemetryClient)
242230
self.assertEqual(self.factory._clients, {}) # No client was stored
243231

232+
client2 = self.factory.get_telemetry_client("test-uuid")
233+
self.assertIsInstance(client2, NoopTelemetryClient)
234+
244235
@patch("databricks.sql.telemetry.telemetry_client.ThreadPoolExecutor")
245236
def test_close(self, mock_executor_class):
246237
"""Test closing a client."""

0 commit comments

Comments
 (0)