1+ """
2+ Test for SEA multi-chunk responses.
3+
4+ This script tests the SEA connector's ability to handle multi-chunk responses correctly.
5+ It runs a query that generates large rows to force multiple chunks and verifies that
6+ the correct number of rows are returned.
7+ """
import csv
import json
import logging
import os
import sys
import time
import traceback
from pathlib import Path

from databricks.sql.client import Connection

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def test_sea_multi_chunk_with_cloud_fetch(requested_row_count=5000):
    """
    Test executing a query that generates multiple chunks using cloud fetch.

    Args:
        requested_row_count: Number of rows to request in the query

    Returns:
        bool: True if the test passed, False otherwise
    """
    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
    access_token = os.environ.get("DATABRICKS_TOKEN")
    catalog = os.environ.get("DATABRICKS_CATALOG")

    # Create output directory for test results
    output_dir = Path("test_results")
    output_dir.mkdir(exist_ok=True)

    # Files to store results
    rows_file = output_dir / "cloud_fetch_rows.csv"
    stats_file = output_dir / "cloud_fetch_stats.json"

    if not all([server_hostname, http_path, access_token]):
        logger.error("Missing required environment variables.")
        logger.error(
            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
        )
        return False

    try:
        # Create connection with cloud fetch enabled
        logger.info("Creating connection for query execution with cloud fetch enabled")
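        # use_sea=True routes execution through the SEA (Statement Execution API)
        # backend; use_cloud_fetch=True asks the server to return large results
        # as links to chunk files that the client downloads directly, rather
        # than as inline rows. (Flag semantics as this test understands them.)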
        connection = Connection(
            server_hostname=server_hostname,
            http_path=http_path,
            access_token=access_token,
            catalog=catalog,
            schema="default",
            use_sea=True,
            user_agent_entry="SEA-Test-Client",
            use_cloud_fetch=True,
        )

        logger.info(
            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
        )
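
        # Each row carries a ~10 KB string, so the default 5,000 rows produce
        # roughly 50 MB of result data, which should be enough to make the
        # server split the result across multiple chunks.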
        # Execute a query that generates large rows to force multiple chunks
        cursor = connection.cursor()
        query = f"""
        SELECT
            id,
            concat('value_', repeat('a', 10000)) as test_value
        FROM range(1, {requested_row_count} + 1) AS t(id)
        """

        logger.info(f"Executing query with cloud fetch to generate {requested_row_count} rows")
        start_time = time.time()
        cursor.execute(query)

        # Fetch all rows
        rows = cursor.fetchall()
        actual_row_count = len(rows)
        end_time = time.time()
        execution_time = end_time - start_time

        logger.info(f"Query executed and rows fetched in {execution_time:.2f} seconds")
        logger.info(f"Requested {requested_row_count} rows, received {actual_row_count} rows")
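
        # A count mismatch here usually means a chunk was dropped or fetched
        # twice; the per-row ID audit below pinpoints which rows are affected.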
        # Write rows to CSV file for inspection
        logger.info(f"Writing rows to {rows_file}")
        with open(rows_file, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['id', 'value_length'])  # Header

            # Extract IDs to check for duplicates and missing values
            row_ids = []
            for row in rows:
                row_id = row[0]
                value_length = len(row[1])
                writer.writerow([row_id, value_length])
                row_ids.append(row_id)

        # Verify row count
        success = actual_row_count == requested_row_count

        # Check for duplicate IDs
        unique_ids = set(row_ids)
        duplicate_count = len(row_ids) - len(unique_ids)

        # Check for missing and unexpected IDs
        expected_ids = set(range(1, requested_row_count + 1))
        missing_ids = expected_ids - unique_ids
        extra_ids = unique_ids - expected_ids

        # Write statistics to JSON file
        stats = {
            "requested_row_count": requested_row_count,
            "actual_row_count": actual_row_count,
            "execution_time_seconds": execution_time,
            "duplicate_count": duplicate_count,
            "missing_ids_count": len(missing_ids),
            "extra_ids_count": len(extra_ids),
            # Limit ID lists to the first 100 entries for readability
            "missing_ids": sorted(missing_ids)[:100],
            "extra_ids": sorted(extra_ids)[:100],
            "success": (
                success
                and duplicate_count == 0
                and not missing_ids
                and not extra_ids
            ),
        }

        with open(stats_file, 'w') as f:
            json.dump(stats, f, indent=2)

        # Log detailed results
        if duplicate_count > 0:
            logger.error(f"❌ FAILED: Found {duplicate_count} duplicate row IDs")
            success = False
        else:
            logger.info("✅ PASSED: No duplicate row IDs found")

        if missing_ids:
            logger.error(f"❌ FAILED: Missing {len(missing_ids)} expected row IDs")
            if len(missing_ids) <= 10:
                logger.error(f"Missing IDs: {sorted(missing_ids)}")
            success = False
        else:
            logger.info("✅ PASSED: All expected row IDs present")

        if extra_ids:
            logger.error(f"❌ FAILED: Found {len(extra_ids)} unexpected row IDs")
            if len(extra_ids) <= 10:
                logger.error(f"Extra IDs: {sorted(extra_ids)}")
            success = False
        else:
            logger.info("✅ PASSED: No unexpected row IDs found")

        if actual_row_count == requested_row_count:
            logger.info("✅ PASSED: Row count matches requested count")
        else:
            logger.error(
                f"❌ FAILED: Row count mismatch. Expected {requested_row_count}, got {actual_row_count}"
            )
            success = False

        # Close resources
        cursor.close()
        connection.close()
        logger.info("Successfully closed SEA session")

        logger.info(f"Test results written to {rows_file} and {stats_file}")
        return success

    except Exception as e:
        logger.error(f"Error during SEA multi-chunk test with cloud fetch: {e}")
        logger.error(traceback.format_exc())
        return False


def main():
    # Check that the required environment variables are set
    required_vars = [
        "DATABRICKS_SERVER_HOSTNAME",
        "DATABRICKS_HTTP_PATH",
        "DATABRICKS_TOKEN",
    ]
    missing_vars = [var for var in required_vars if not os.environ.get(var)]

    if missing_vars:
        logger.error(
            f"Missing required environment variables: {', '.join(missing_vars)}"
        )
        logger.error("Please set these variables before running the tests.")
        sys.exit(1)

    # Get the row count from the command line, or use the default
    requested_row_count = 5000

    if len(sys.argv) > 1:
        try:
            requested_row_count = int(sys.argv[1])
        except ValueError:
            logger.error(f"Invalid row count: {sys.argv[1]}")
            logger.error("Please provide a valid integer for row count.")
            sys.exit(1)

    logger.info(f"Testing with {requested_row_count} rows")

    # Run the multi-chunk test with cloud fetch
    success = test_sea_multi_chunk_with_cloud_fetch(requested_row_count)

    # Report results
    if success:
        logger.info("✅ TEST PASSED: Multi-chunk cloud fetch test completed successfully")
        sys.exit(0)
    else:
        logger.error("❌ TEST FAILED: Multi-chunk cloud fetch test encountered errors")
        sys.exit(1)


if __name__ == "__main__":
    main()