Skip to content

Commit ce08e01

Browse files
cloudfetch fix?
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 1fef8f3 commit ce08e01

File tree

7 files changed

+539
-56
lines changed

7 files changed

+539
-56
lines changed

examples/experimental/sea_connector_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"test_sea_sync_query",
1919
"test_sea_async_query",
2020
"test_sea_metadata",
21+
"test_sea_multi_chunk",
2122
]
2223

2324

@@ -27,6 +28,12 @@ def run_test_module(module_name: str) -> bool:
2728
os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py"
2829
)
2930

31+
# Handle the multi-chunk test which is in the main directory
32+
if module_name == "test_sea_multi_chunk":
33+
module_path = os.path.join(
34+
os.path.dirname(os.path.abspath(__file__)), f"{module_name}.py"
35+
)
36+
3037
# Simply run the module as a script - each module handles its own test execution
3138
result = subprocess.run(
3239
[sys.executable, module_path], capture_output=True, text=True
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
"""
2+
Test for SEA multi-chunk responses.
3+
4+
This script tests the SEA connector's ability to handle multi-chunk responses correctly.
5+
It runs queries that generate large rows to force multiple chunks and verifies that
6+
the correct number of rows are returned.
7+
"""
8+
import os
9+
import sys
10+
import logging
11+
import time
12+
from databricks.sql.client import Connection
13+
14+
# Module-level logging setup: INFO so per-test progress and results are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
16+
17+
18+
def test_sea_multi_chunk_with_cloud_fetch(requested_row_count=5000):
    """
    Test executing a query that generates multiple chunks using cloud fetch.

    Each row carries a ~10 KB string so the result set exceeds a single
    chunk and exercises the multi-chunk fetch path.

    Args:
        requested_row_count: Number of rows to request in the query

    Returns:
        bool: True if the test passed, False otherwise
    """
    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
    access_token = os.environ.get("DATABRICKS_TOKEN")
    catalog = os.environ.get("DATABRICKS_CATALOG")

    if not all([server_hostname, http_path, access_token]):
        logger.error("Missing required environment variables.")
        logger.error(
            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
        )
        return False

    connection = None
    cursor = None
    try:
        # Create connection with cloud fetch enabled
        logger.info("Creating connection for query execution with cloud fetch enabled")
        connection = Connection(
            server_hostname=server_hostname,
            http_path=http_path,
            access_token=access_token,
            catalog=catalog,
            schema="default",
            use_sea=True,
            user_agent_entry="SEA-Test-Client",
            use_cloud_fetch=True,
        )

        logger.info(
            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
        )

        # Execute a query that generates large rows to force multiple chunks
        cursor = connection.cursor()
        query = f"""
        SELECT
            id,
            concat('value_', repeat('a', 10000)) as test_value
        FROM range(1, {requested_row_count} + 1) AS t(id)
        """

        logger.info(
            f"Executing query with cloud fetch to generate {requested_row_count} rows"
        )
        start_time = time.time()
        cursor.execute(query)

        # Fetch all rows
        rows = cursor.fetchall()
        actual_row_count = len(rows)
        end_time = time.time()

        logger.info(f"Query executed in {end_time - start_time:.2f} seconds")
        logger.info(
            f"Requested {requested_row_count} rows, received {actual_row_count} rows"
        )

        # Verify row count
        success = actual_row_count == requested_row_count
        if success:
            logger.info("✅ PASSED: Received correct number of rows")
        else:
            logger.error(
                f"❌ FAILED: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
            )

        return success

    except Exception as e:
        logger.error(f"Error during SEA multi-chunk test with cloud fetch: {str(e)}")
        import traceback

        logger.error(traceback.format_exc())
        return False

    finally:
        # Always release resources, even when the query or fetch raised,
        # so a failed run does not leak the cursor or the SEA session.
        # Close failures are logged but suppressed so they cannot mask
        # the function's boolean result.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                logger.warning("Failed to close cursor", exc_info=True)
        if connection is not None:
            try:
                connection.close()
                logger.info("Successfully closed SEA session")
            except Exception:
                logger.warning("Failed to close connection", exc_info=True)
107+
def test_sea_multi_chunk_without_cloud_fetch(requested_row_count=100):
    """
    Test executing a query that generates multiple chunks without using cloud fetch.

    Uses a smaller row count and smaller per-row payload than the cloud-fetch
    variant to stay within inline result limits.

    Args:
        requested_row_count: Number of rows to request in the query

    Returns:
        bool: True if the test passed, False otherwise
    """
    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
    access_token = os.environ.get("DATABRICKS_TOKEN")
    catalog = os.environ.get("DATABRICKS_CATALOG")

    if not all([server_hostname, http_path, access_token]):
        logger.error("Missing required environment variables.")
        logger.error(
            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
        )
        return False

    connection = None
    cursor = None
    try:
        # Create connection with cloud fetch disabled
        logger.info("Creating connection for query execution with cloud fetch disabled")
        connection = Connection(
            server_hostname=server_hostname,
            http_path=http_path,
            access_token=access_token,
            catalog=catalog,
            schema="default",
            use_sea=True,
            user_agent_entry="SEA-Test-Client",
            use_cloud_fetch=False,
            enable_query_result_lz4_compression=False,
        )

        logger.info(
            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
        )

        # For non-cloud fetch, use a smaller row count to avoid exceeding inline limits
        cursor = connection.cursor()
        query = f"""
        SELECT
            id,
            concat('value_', repeat('a', 100)) as test_value
        FROM range(1, {requested_row_count} + 1) AS t(id)
        """

        logger.info(
            f"Executing query without cloud fetch to generate {requested_row_count} rows"
        )
        start_time = time.time()
        cursor.execute(query)

        # Fetch all rows
        rows = cursor.fetchall()
        actual_row_count = len(rows)
        end_time = time.time()

        logger.info(f"Query executed in {end_time - start_time:.2f} seconds")
        logger.info(
            f"Requested {requested_row_count} rows, received {actual_row_count} rows"
        )

        # Verify row count
        success = actual_row_count == requested_row_count
        if success:
            logger.info("✅ PASSED: Received correct number of rows")
        else:
            logger.error(
                f"❌ FAILED: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
            )

        return success

    except Exception as e:
        logger.error(f"Error during SEA multi-chunk test without cloud fetch: {str(e)}")
        import traceback

        logger.error(traceback.format_exc())
        return False

    finally:
        # Always release resources, even when the query or fetch raised,
        # so a failed run does not leak the cursor or the SEA session.
        # Close failures are logged but suppressed so they cannot mask
        # the function's boolean result.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                logger.warning("Failed to close cursor", exc_info=True)
        if connection is not None:
            try:
                connection.close()
                logger.info("Successfully closed SEA session")
            except Exception:
                logger.warning("Failed to close connection", exc_info=True)
197+
def main():
198+
# Check if required environment variables are set
199+
required_vars = [
200+
"DATABRICKS_SERVER_HOSTNAME",
201+
"DATABRICKS_HTTP_PATH",
202+
"DATABRICKS_TOKEN",
203+
]
204+
missing_vars = [var for var in required_vars if not os.environ.get(var)]
205+
206+
if missing_vars:
207+
logger.error(
208+
f"Missing required environment variables: {', '.join(missing_vars)}"
209+
)
210+
logger.error("Please set these variables before running the tests.")
211+
sys.exit(1)
212+
213+
# Get row count from command line or use default
214+
cloud_fetch_row_count = 5000
215+
non_cloud_fetch_row_count = 100
216+
217+
if len(sys.argv) > 1:
218+
try:
219+
cloud_fetch_row_count = int(sys.argv[1])
220+
except ValueError:
221+
logger.error(f"Invalid row count for cloud fetch: {sys.argv[1]}")
222+
logger.error("Please provide a valid integer for row count.")
223+
sys.exit(1)
224+
225+
if len(sys.argv) > 2:
226+
try:
227+
non_cloud_fetch_row_count = int(sys.argv[2])
228+
except ValueError:
229+
logger.error(f"Invalid row count for non-cloud fetch: {sys.argv[2]}")
230+
logger.error("Please provide a valid integer for row count.")
231+
sys.exit(1)
232+
233+
logger.info(
234+
f"Testing with {cloud_fetch_row_count} rows for cloud fetch and {non_cloud_fetch_row_count} rows for non-cloud fetch"
235+
)
236+
237+
# Test with cloud fetch
238+
with_cloud_fetch_success = test_sea_multi_chunk_with_cloud_fetch(
239+
cloud_fetch_row_count
240+
)
241+
logger.info(
242+
f"Multi-chunk test with cloud fetch: {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}"
243+
)
244+
245+
# Test without cloud fetch
246+
without_cloud_fetch_success = test_sea_multi_chunk_without_cloud_fetch(
247+
non_cloud_fetch_row_count
248+
)
249+
logger.info(
250+
f"Multi-chunk test without cloud fetch: {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}"
251+
)
252+
253+
# Compare results
254+
logger.info("\n=== RESULTS SUMMARY ===")
255+
logger.info(
256+
f"Cloud fetch test ({cloud_fetch_row_count} rows): {'✅ PASSED' if with_cloud_fetch_success else '❌ FAILED'}"
257+
)
258+
logger.info(
259+
f"Non-cloud fetch test ({non_cloud_fetch_row_count} rows): {'✅ PASSED' if without_cloud_fetch_success else '❌ FAILED'}"
260+
)
261+
262+
if with_cloud_fetch_success and without_cloud_fetch_success:
263+
logger.info("✅ ALL TESTS PASSED")
264+
sys.exit(0)
265+
else:
266+
logger.info("❌ SOME TESTS FAILED")
267+
sys.exit(1)
268+
269+
270+
# Script entry point: run both multi-chunk tests; exit status reflects results.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)