
Commit d7ab57f

init cloud fetch stuffs
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 71266b1 commit d7ab57f

8 files changed: +797 −51 lines changed

examples/experimental/sea_connector_test.py

Lines changed: 119 additions & 2 deletions
@@ -273,9 +273,126 @@ def test_sea_session():
     logger.info("SEA session test completed successfully")
 
 
+def test_sea_result_set_arrow_external_links():
+    """
+    Test the SEA result set implementation with ARROW format and EXTERNAL_LINKS disposition.
+
+    This function connects to a Databricks SQL endpoint using the SEA backend,
+    executes a query that returns a large result set (which will use EXTERNAL_LINKS disposition),
+    and tests the various fetch methods to verify the result set implementation works correctly.
+    """
+    server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
+    http_path = os.environ.get("DATABRICKS_HTTP_PATH")
+    access_token = os.environ.get("DATABRICKS_TOKEN")
+    catalog = os.environ.get("DATABRICKS_CATALOG", "samples")
+    schema = os.environ.get("DATABRICKS_SCHEMA", "tpch")
+
+    if not all([server_hostname, http_path, access_token]):
+        logger.error("Missing required environment variables.")
+        logger.error(
+            "Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
+        )
+        sys.exit(1)
+
+    try:
+        # Create connection with SEA backend
+        logger.info("Creating connection with SEA backend...")
+        connection = Connection(
+            server_hostname=server_hostname,
+            http_path=http_path,
+            access_token=access_token,
+            catalog=catalog,
+            schema=schema,
+            use_sea=True,
+            use_cloud_fetch=True,  # Enable cloud fetch to trigger EXTERNAL_LINKS + ARROW
+            user_agent_entry="SEA-Test-Client",
+            # Use a smaller arraysize to potentially force multiple chunks
+            arraysize=1000,
+        )
+
+        logger.info(
+            f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
+        )
+
+        # Create cursor
+        cursor = connection.cursor()
+
+        # Execute a query that returns a large result set (will use EXTERNAL_LINKS disposition),
+        # using a CROSS JOIN to generate enough rows to span multiple chunks
+        logger.info("Executing query: SELECT a.id as id1, b.id as id2 FROM range(1, 1000) a CROSS JOIN range(1, 1000) b LIMIT 100000")
+        cursor.execute("SELECT a.id as id1, b.id as id2 FROM range(1, 1000) a CROSS JOIN range(1, 1000) b LIMIT 100000")
+
+        # The manifest is not easily accessible from the SeaResultSet, so continue with the test.
+        # Note: the server might optimize results to fit into a single chunk, but the implementation
+        # is designed to handle multiple chunks by fetching additional chunks when needed.
+        logger.info("Proceeding with fetch operations...")
+
+        # Test fetchone
+        logger.info("Testing fetchone...")
+        row = cursor.fetchone()
+        logger.info(f"First row: {row}")
+
+        # Test fetchmany with a moderate size
+        fetch_size = 500
+        logger.info(f"Testing fetchmany({fetch_size})...")
+        rows = cursor.fetchmany(fetch_size)
+        logger.info(f"Fetched {len(rows)} rows with fetchmany")
+
+        # Test fetchall for the remaining rows
+        logger.info("Testing fetchall...")
+        remaining_rows = cursor.fetchall()
+        logger.info(f"Fetched {len(remaining_rows)} remaining rows with fetchall")
+
+        # Calculate total rows fetched
+        total_rows = 1 + len(rows) + len(remaining_rows)
+        logger.info(f"Total rows fetched: {total_rows}")
+
+        # Execute another query to test the Arrow fetch methods
+        logger.info("\nExecuting second query for Arrow testing: SELECT * FROM range(1, 20000) as id LIMIT 20000")
+        cursor.execute("SELECT * FROM range(1, 20000) as id LIMIT 20000")
+
+        try:
+            # Test fetchmany_arrow with a moderate size
+            arrow_fetch_size = 1000
+            logger.info(f"Testing fetchmany_arrow({arrow_fetch_size})...")
+            arrow_batch = cursor.fetchmany_arrow(arrow_fetch_size)
+            logger.info(f"Arrow batch num rows: {arrow_batch.num_rows}")
+            logger.info(f"Arrow batch columns: {arrow_batch.column_names}")
+
+            # Test fetchall_arrow
+            logger.info("Testing fetchall_arrow...")
+            remaining_arrow_batch = cursor.fetchall_arrow()
+            logger.info(f"Remaining arrow batch num rows: {remaining_arrow_batch.num_rows}")
+
+            # Calculate total rows fetched with Arrow
+            total_arrow_rows = arrow_batch.num_rows + remaining_arrow_batch.num_rows
+            logger.info(f"Total rows fetched with Arrow: {total_arrow_rows}")
+
+        except ImportError:
+            logger.warning("PyArrow not installed, skipping Arrow tests")
+
+        # Close cursor and connection
+        cursor.close()
+        connection.close()
+        logger.info("Successfully closed SEA session")
+
+    except Exception as e:
+        logger.error(f"Error during SEA result set test: {str(e)}")
+        import traceback
+        logger.error(traceback.format_exc())
+        sys.exit(1)
+
+    logger.info("SEA result set test with ARROW format and EXTERNAL_LINKS disposition completed successfully")
+
+
 if __name__ == "__main__":
     # Test session management
-    test_sea_session()
+    # test_sea_session()
 
     # Test result set implementation with metadata commands
-    test_sea_result_set_json_array_inline()
+    # test_sea_result_set_json_array_inline()
+
+    # Test result set implementation with ARROW format and EXTERNAL_LINKS disposition
+    test_sea_result_set_arrow_external_links()
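To try the new test locally, run python examples/experimental/sea_connector_test.py from the repo root with the environment set, or drive it from Python directly. A minimal sketch, with placeholder endpoint values (substitute your own workspace hostname, warehouse HTTP path, and token):

# All three values below are placeholders, not real credentials.
import os

os.environ["DATABRICKS_SERVER_HOSTNAME"] = "<workspace>.cloud.databricks.com"
os.environ["DATABRICKS_HTTP_PATH"] = "/sql/1.0/warehouses/<warehouse-id>"
os.environ["DATABRICKS_TOKEN"] = "<personal-access-token>"

# The test reads the variables at call time, so setting them first is enough.
# (Assumes the repo root is on sys.path so the examples package resolves.)
from examples.experimental.sea_connector_test import (
    test_sea_result_set_arrow_external_links,
)

test_sea_result_set_arrow_external_links()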

src/databricks/sql/backend/models/base.py

Lines changed: 11 additions & 2 deletions
@@ -34,6 +34,12 @@ class ExternalLink:
     external_link: str
     expiration: str
     chunk_index: int
+    byte_count: int = 0
+    row_count: int = 0
+    row_offset: int = 0
+    next_chunk_index: Optional[int] = None
+    next_chunk_internal_link: Optional[str] = None
+    http_headers: Optional[Dict[str, str]] = None
 
 
 @dataclass
@@ -61,8 +67,11 @@ class ColumnInfo:
 class ResultManifest:
     """Manifest information for a result set."""
 
-    schema: List[ColumnInfo]
+    format: str
+    schema: Dict[str, Any]  # Will contain column information
     total_row_count: int
     total_byte_count: int
+    total_chunk_count: int
     truncated: bool = False
-    chunk_count: Optional[int] = None
+    chunks: Optional[List[Dict[str, Any]]] = None
+    result_compression: Optional[str] = None
src/databricks/sql/backend/models/responses.py

Lines changed: 125 additions & 4 deletions
@@ -13,6 +13,8 @@
     ResultManifest,
     ResultData,
     ServiceError,
+    ExternalLink,
+    ColumnInfo,
 )
 
 
@@ -42,12 +44,55 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
             error=error,
             sql_state=status_data.get("sql_state"),
         )
+
+        # Parse manifest
+        manifest = None
+        if "manifest" in data:
+            manifest_data = data["manifest"]
+            manifest = ResultManifest(
+                format=manifest_data.get("format", ""),
+                schema=manifest_data.get("schema", {}),
+                total_row_count=manifest_data.get("total_row_count", 0),
+                total_byte_count=manifest_data.get("total_byte_count", 0),
+                total_chunk_count=manifest_data.get("total_chunk_count", 0),
+                truncated=manifest_data.get("truncated", False),
+                chunks=manifest_data.get("chunks"),
+                result_compression=manifest_data.get("result_compression"),
+            )
+
+        # Parse result data
+        result = None
+        if "result" in data:
+            result_data = data["result"]
+            external_links = None
+
+            if "external_links" in result_data:
+                external_links = []
+                for link_data in result_data["external_links"]:
+                    external_links.append(
+                        ExternalLink(
+                            external_link=link_data.get("external_link", ""),
+                            expiration=link_data.get("expiration", ""),
+                            chunk_index=link_data.get("chunk_index", 0),
+                            byte_count=link_data.get("byte_count", 0),
+                            row_count=link_data.get("row_count", 0),
+                            row_offset=link_data.get("row_offset", 0),
+                            next_chunk_index=link_data.get("next_chunk_index"),
+                            next_chunk_internal_link=link_data.get("next_chunk_internal_link"),
+                            http_headers=link_data.get("http_headers"),
+                        )
+                    )
+
+            result = ResultData(
+                data=result_data.get("data_array"),
+                external_links=external_links,
+            )
 
         return cls(
             statement_id=data.get("statement_id", ""),
             status=status,
-            manifest=data.get("manifest"),  # We'll parse this more fully if needed
-            result=data.get("result"),  # We'll parse this more fully if needed
+            manifest=manifest,
+            result=result,
         )
 
 
@@ -77,12 +122,55 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
             error=error,
             sql_state=status_data.get("sql_state"),
         )
+
+        # Parse manifest
+        manifest = None
+        if "manifest" in data:
+            manifest_data = data["manifest"]
+            manifest = ResultManifest(
+                format=manifest_data.get("format", ""),
+                schema=manifest_data.get("schema", {}),
+                total_row_count=manifest_data.get("total_row_count", 0),
+                total_byte_count=manifest_data.get("total_byte_count", 0),
+                total_chunk_count=manifest_data.get("total_chunk_count", 0),
+                truncated=manifest_data.get("truncated", False),
+                chunks=manifest_data.get("chunks"),
+                result_compression=manifest_data.get("result_compression"),
+            )
+
+        # Parse result data
+        result = None
+        if "result" in data:
+            result_data = data["result"]
+            external_links = None
+
+            if "external_links" in result_data:
+                external_links = []
+                for link_data in result_data["external_links"]:
+                    external_links.append(
+                        ExternalLink(
+                            external_link=link_data.get("external_link", ""),
+                            expiration=link_data.get("expiration", ""),
+                            chunk_index=link_data.get("chunk_index", 0),
+                            byte_count=link_data.get("byte_count", 0),
+                            row_count=link_data.get("row_count", 0),
+                            row_offset=link_data.get("row_offset", 0),
+                            next_chunk_index=link_data.get("next_chunk_index"),
+                            next_chunk_internal_link=link_data.get("next_chunk_internal_link"),
+                            http_headers=link_data.get("http_headers"),
+                        )
+                    )
+
+            result = ResultData(
+                data=result_data.get("data_array"),
+                external_links=external_links,
+            )
 
         return cls(
             statement_id=data.get("statement_id", ""),
             status=status,
-            manifest=data.get("manifest"),  # We'll parse this more fully if needed
-            result=data.get("result"),  # We'll parse this more fully if needed
+            manifest=manifest,
+            result=result,
         )
 
 
@@ -96,3 +184,36 @@ class CreateSessionResponse:
     def from_dict(cls, data: Dict[str, Any]) -> "CreateSessionResponse":
         """Create a CreateSessionResponse from a dictionary."""
         return cls(session_id=data.get("session_id", ""))
+
+
+@dataclass
+class GetChunksResponse:
+    """Response from getting chunks for a statement."""
+
+    statement_id: str
+    external_links: List[ExternalLink]
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "GetChunksResponse":
+        """Create a GetChunksResponse from a dictionary."""
+        external_links = []
+        if "external_links" in data:
+            for link_data in data["external_links"]:
+                external_links.append(
+                    ExternalLink(
+                        external_link=link_data.get("external_link", ""),
+                        expiration=link_data.get("expiration", ""),
+                        chunk_index=link_data.get("chunk_index", 0),
+                        byte_count=link_data.get("byte_count", 0),
+                        row_count=link_data.get("row_count", 0),
+                        row_offset=link_data.get("row_offset", 0),
+                        next_chunk_index=link_data.get("next_chunk_index"),
+                        next_chunk_internal_link=link_data.get("next_chunk_internal_link"),
+                        http_headers=link_data.get("http_headers"),
+                    )
+                )
+
+        return cls(
+            statement_id=data.get("statement_id", ""),
+            external_links=external_links,
+        )