Skip to content

Commit dba08cd

Browse files
rename proxy specific attrs with proxy prefix
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 328c3bd commit dba08cd

File tree

4 files changed

+53
-37
lines changed

4 files changed

+53
-37
lines changed

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ class SeaHttpClient:
3030
retry_policy: Union[DatabricksRetryPolicy, int]
3131
_pool: Optional[Union[HTTPConnectionPool, HTTPSConnectionPool]]
3232
proxy_uri: Optional[str]
33-
realhost: Optional[str]
34-
realport: Optional[int]
33+
proxy_host: Optional[str]
34+
proxy_port: Optional[int]
3535
proxy_auth: Optional[Dict[str, str]]
3636

3737
def __init__(
@@ -136,15 +136,15 @@ def __init__(
136136

137137
if proxy:
138138
parsed_proxy = urllib.parse.urlparse(proxy)
139-
self.realhost = self.host
140-
self.realport = self.port
139+
self.proxy_host = self.host
140+
self.proxy_port = self.port
141141
self.proxy_uri = proxy
142142
self.host = parsed_proxy.hostname
143143
self.port = parsed_proxy.port or (443 if self.scheme == "https" else 80)
144144
self.proxy_auth = self._basic_proxy_auth_headers(parsed_proxy)
145145
else:
146-
self.realhost = None
147-
self.realport = None
146+
self.proxy_host = None
147+
self.proxy_port = None
148148
self.proxy_auth = None
149149
self.proxy_uri = None
150150

@@ -186,8 +186,8 @@ def _open(self):
186186
proxy_headers=self.proxy_auth,
187187
)
188188
self._pool = proxy_manager.connection_from_host(
189-
host=self.realhost,
190-
port=self.realport,
189+
host=self.proxy_host,
190+
port=self.proxy_port,
191191
scheme=self.scheme,
192192
pool_kwargs=pool_kwargs,
193193
)
@@ -201,7 +201,7 @@ def close(self):
201201

202202
def using_proxy(self) -> bool:
203203
"""Check if proxy is being used."""
204-
return self.realhost is not None
204+
return self.proxy_host is not None
205205

206206
def set_retry_command_type(self, command_type: CommandType):
207207
"""Set the command type for retry policy decision making."""

tests/e2e/test_concurrent_telemetry.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66
import pytest
77

88
from databricks.sql.telemetry.models.enums import StatementType
9-
from databricks.sql.telemetry.telemetry_client import TelemetryClient, TelemetryClientFactory
9+
from databricks.sql.telemetry.telemetry_client import (
10+
TelemetryClient,
11+
TelemetryClientFactory,
12+
)
1013
from tests.e2e.test_driver import PySQLPytestTestCase
1114

15+
1216
def run_in_threads(target, num_threads, pass_index=False):
1317
"""Helper to run target function in multiple threads."""
1418
threads = [
@@ -22,7 +26,6 @@ def run_in_threads(target, num_threads, pass_index=False):
2226

2327

2428
class TestE2ETelemetry(PySQLPytestTestCase):
25-
2629
@pytest.fixture(autouse=True)
2730
def telemetry_setup_teardown(self):
2831
"""
@@ -31,7 +34,7 @@ def telemetry_setup_teardown(self):
3134
this robust and automatic.
3235
"""
3336
try:
34-
yield
37+
yield
3538
finally:
3639
if TelemetryClientFactory._executor:
3740
TelemetryClientFactory._executor.shutdown(wait=True)
@@ -68,20 +71,25 @@ def callback_wrapper(self_client, future, sent_count):
6871
captured_futures.append(future)
6972
original_callback(self_client, future, sent_count)
7073

71-
with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper), \
72-
patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper):
74+
with patch.object(
75+
TelemetryClient, "_send_telemetry", send_telemetry_wrapper
76+
), patch.object(
77+
TelemetryClient, "_telemetry_request_callback", callback_wrapper
78+
):
7379

7480
def execute_query_worker(thread_id):
7581
"""Each thread creates a connection and executes a query."""
7682

7783
time.sleep(random.uniform(0, 0.05))
78-
79-
with self.connection(extra_params={"force_enable_telemetry": True}) as conn:
84+
85+
with self.connection(
86+
extra_params={"force_enable_telemetry": True}
87+
) as conn:
8088
# Capture the session ID from the connection before executing the query
8189
session_id_hex = conn.get_session_id_hex()
8290
with capture_lock:
8391
captured_session_ids.append(session_id_hex)
84-
92+
8593
with conn.cursor() as cursor:
8694
cursor.execute(f"SELECT {thread_id}")
8795
# Capture the statement ID after executing the query
@@ -97,7 +105,10 @@ def execute_query_worker(thread_id):
97105
start_time = time.time()
98106
expected_event_count = num_threads
99107

100-
while len(captured_futures) < expected_event_count and time.time() - start_time < timeout_seconds:
108+
while (
109+
len(captured_futures) < expected_event_count
110+
and time.time() - start_time < timeout_seconds
111+
):
101112
time.sleep(0.1)
102113

103114
done, not_done = wait(captured_futures, timeout=timeout_seconds)
@@ -115,30 +126,37 @@ def execute_query_worker(thread_id):
115126

116127
assert not captured_exceptions
117128
assert len(captured_responses) > 0
118-
129+
119130
total_successful_events = 0
120131
for response in captured_responses:
121132
assert "errors" not in response or not response["errors"]
122133
if "numProtoSuccess" in response:
123134
total_successful_events += response["numProtoSuccess"]
124135
assert total_successful_events == num_threads * 2
125136

126-
assert len(captured_telemetry) == num_threads * 2 # 2 events per thread (initial_telemetry_log, latency_log (execute))
137+
assert (
138+
len(captured_telemetry) == num_threads * 2
139+
) # 2 events per thread (initial_telemetry_log, latency_log (execute))
127140
assert len(captured_session_ids) == num_threads # One session ID per thread
128-
assert len(captured_statement_ids) == num_threads # One statement ID per thread (per query)
141+
assert (
142+
len(captured_statement_ids) == num_threads
143+
) # One statement ID per thread (per query)
129144

130145
# Separate initial logs from latency logs
131146
initial_logs = [
132-
e for e in captured_telemetry
147+
e
148+
for e in captured_telemetry
133149
if e.entry.sql_driver_log.operation_latency_ms is None
134150
and e.entry.sql_driver_log.driver_connection_params is not None
135151
and e.entry.sql_driver_log.system_configuration is not None
136152
]
137153
latency_logs = [
138-
e for e in captured_telemetry
139-
if e.entry.sql_driver_log.operation_latency_ms is not None
140-
and e.entry.sql_driver_log.sql_statement_id is not None
141-
and e.entry.sql_driver_log.sql_operation.statement_type == StatementType.QUERY
154+
e
155+
for e in captured_telemetry
156+
if e.entry.sql_driver_log.operation_latency_ms is not None
157+
and e.entry.sql_driver_log.sql_statement_id is not None
158+
and e.entry.sql_driver_log.sql_operation.statement_type
159+
== StatementType.QUERY
142160
]
143161

144162
# Verify counts
@@ -171,4 +189,4 @@ def execute_query_worker(thread_id):
171189
for event in latency_logs:
172190
log = event.entry.sql_driver_log
173191
assert log.sql_statement_id in captured_statement_ids
174-
assert log.session_id in captured_session_ids
192+
assert log.session_id in captured_session_ids

tests/unit/test_telemetry.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def mock_telemetry_client():
3030
auth_provider=auth_provider,
3131
host_url="test-host.com",
3232
executor=executor,
33-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
33+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
3434
)
3535

3636

@@ -215,7 +215,7 @@ def test_client_lifecycle_flow(self):
215215
session_id_hex=session_id_hex,
216216
auth_provider=auth_provider,
217217
host_url="test-host.com",
218-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
218+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
219219
)
220220

221221
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -240,7 +240,7 @@ def test_disabled_telemetry_flow(self):
240240
session_id_hex=session_id_hex,
241241
auth_provider=None,
242242
host_url="test-host.com",
243-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
243+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
244244
)
245245

246246
client = TelemetryClientFactory.get_telemetry_client(session_id_hex)
@@ -260,7 +260,7 @@ def test_factory_error_handling(self):
260260
session_id_hex=session_id,
261261
auth_provider=AccessTokenAuthProvider("token"),
262262
host_url="test-host.com",
263-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
263+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
264264
)
265265

266266
# Should fall back to NoopTelemetryClient
@@ -279,7 +279,7 @@ def test_factory_shutdown_flow(self):
279279
session_id_hex=session,
280280
auth_provider=AccessTokenAuthProvider("token"),
281281
host_url="test-host.com",
282-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
282+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
283283
)
284284

285285
# Factory should be initialized
@@ -342,9 +342,7 @@ def _mock_ff_response(self, mock_requests_get, enabled: bool):
342342
mock_requests_get.return_value = mock_response
343343

344344
@patch("databricks.sql.common.feature_flag.requests.get")
345-
def test_telemetry_enabled_when_flag_is_true(
346-
self, mock_requests_get, MockSession
347-
):
345+
def test_telemetry_enabled_when_flag_is_true(self, mock_requests_get, MockSession):
348346
"""Telemetry should be ON when enable_telemetry=True and server flag is 'true'."""
349347
self._mock_ff_response(mock_requests_get, enabled=True)
350348
mock_session_instance = MockSession.return_value
@@ -405,4 +403,4 @@ def test_telemetry_disabled_when_flag_request_fails(
405403
assert conn.telemetry_enabled is False
406404
mock_requests_get.assert_called_once()
407405
client = TelemetryClientFactory.get_telemetry_client("test-session-ff-fail")
408-
assert isinstance(client, NoopTelemetryClient)
406+
assert isinstance(client, NoopTelemetryClient)

tests/unit/test_telemetry_retry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_client(self, session_id, num_retries=3):
5151
session_id_hex=session_id,
5252
auth_provider=None,
5353
host_url="test.databricks.com",
54-
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE
54+
batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE,
5555
)
5656
client = TelemetryClientFactory.get_telemetry_client(session_id)
5757

0 commit comments

Comments
 (0)