Skip to content

Commit c62f76d

Browse files
remove un-necessary changes
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent c075b07 commit c62f76d

File tree

9 files changed

+240
-296
lines changed

9 files changed

+240
-296
lines changed
Lines changed: 56 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,115 +1,66 @@
1-
"""
2-
Main script to run all SEA connector tests.
3-
4-
This script runs all the individual test modules and displays
5-
a summary of test results with visual indicators.
6-
"""
71
import os
82
import sys
93
import logging
10-
import subprocess
11-
from typing import List, Tuple
4+
from databricks.sql.client import Connection
125

136
logging.basicConfig(level=logging.DEBUG)
147
logger = logging.getLogger(__name__)
158

16-
TEST_MODULES = [
17-
"test_sea_session",
18-
"test_sea_sync_query",
19-
"test_sea_async_query",
20-
"test_sea_metadata",
21-
]
22-
23-
24-
def run_test_module(module_name: str) -> bool:
25-
"""Run a test module and return success status."""
26-
module_path = os.path.join(
27-
os.path.dirname(os.path.abspath(__file__)), "tests", f"{module_name}.py"
28-
)
29-
30-
# Simply run the module as a script - each module handles its own test execution
31-
result = subprocess.run(
32-
[sys.executable, module_path], capture_output=True, text=True
33-
)
34-
35-
# Log the output from the test module
36-
if result.stdout:
37-
for line in result.stdout.strip().split("\n"):
38-
logger.info(line)
39-
40-
if result.stderr:
41-
for line in result.stderr.strip().split("\n"):
42-
logger.error(line)
43-
44-
return result.returncode == 0
45-
46-
47-
def run_tests() -> List[Tuple[str, bool]]:
48-
"""Run all tests and return results."""
49-
results = []
50-
51-
for module_name in TEST_MODULES:
52-
try:
53-
logger.info(f"\n{'=' * 50}")
54-
logger.info(f"Running test: {module_name}")
55-
logger.info(f"{'-' * 50}")
56-
57-
success = run_test_module(module_name)
58-
results.append((module_name, success))
59-
60-
status = "✅ PASSED" if success else "❌ FAILED"
61-
logger.info(f"Test {module_name}: {status}")
62-
63-
except Exception as e:
64-
logger.error(f"Error loading or running test {module_name}: {str(e)}")
65-
import traceback
66-
67-
logger.error(traceback.format_exc())
68-
results.append((module_name, False))
69-
70-
return results
71-
72-
73-
def print_summary(results: List[Tuple[str, bool]]) -> None:
74-
"""Print a summary of test results."""
75-
logger.info(f"\n{'=' * 50}")
76-
logger.info("TEST SUMMARY")
77-
logger.info(f"{'-' * 50}")
78-
79-
passed = sum(1 for _, success in results if success)
80-
total = len(results)
81-
82-
for module_name, success in results:
83-
status = "✅ PASSED" if success else "❌ FAILED"
84-
logger.info(f"{status} - {module_name}")
85-
86-
logger.info(f"{'-' * 50}")
87-
logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}")
88-
logger.info(f"{'=' * 50}")
89-
90-
91-
if __name__ == "__main__":
92-
# Check if required environment variables are set
93-
required_vars = [
94-
"DATABRICKS_SERVER_HOSTNAME",
95-
"DATABRICKS_HTTP_PATH",
96-
"DATABRICKS_TOKEN",
97-
]
98-
missing_vars = [var for var in required_vars if not os.environ.get(var)]
99-
100-
if missing_vars:
101-
logger.error(
102-
f"Missing required environment variables: {', '.join(missing_vars)}"
9+
def test_sea_session():
10+
"""
11+
Test opening and closing a SEA session using the connector.
12+
13+
This function connects to a Databricks SQL endpoint using the SEA backend,
14+
opens a session, and then closes it.
15+
16+
Required environment variables:
17+
- DATABRICKS_SERVER_HOSTNAME: Databricks server hostname
18+
- DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint
19+
- DATABRICKS_TOKEN: Personal access token for authentication
20+
"""
21+
22+
server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
23+
http_path = os.environ.get("DATABRICKS_HTTP_PATH")
24+
access_token = os.environ.get("DATABRICKS_TOKEN")
25+
catalog = os.environ.get("DATABRICKS_CATALOG")
26+
27+
if not all([server_hostname, http_path, access_token]):
28+
logger.error("Missing required environment variables.")
29+
logger.error("Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN.")
30+
sys.exit(1)
31+
32+
logger.info(f"Connecting to {server_hostname}")
33+
logger.info(f"HTTP Path: {http_path}")
34+
if catalog:
35+
logger.info(f"Using catalog: {catalog}")
36+
37+
try:
38+
logger.info("Creating connection with SEA backend...")
39+
connection = Connection(
40+
server_hostname=server_hostname,
41+
http_path=http_path,
42+
access_token=access_token,
43+
catalog=catalog,
44+
schema="default",
45+
use_sea=True,
46+
user_agent_entry="SEA-Test-Client" # add custom user agent
10347
)
104-
logger.error("Please set these variables before running the tests.")
48+
49+
logger.info(f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}")
50+
logger.info(f"backend type: {type(connection.session.backend)}")
51+
52+
# Close the connection
53+
logger.info("Closing the SEA session...")
54+
connection.close()
55+
logger.info("Successfully closed SEA session")
56+
57+
except Exception as e:
58+
logger.error(f"Error testing SEA session: {str(e)}")
59+
import traceback
60+
logger.error(traceback.format_exc())
10561
sys.exit(1)
62+
63+
logger.info("SEA session test completed successfully")
10664

107-
# Run all tests
108-
results = run_tests()
109-
110-
# Print summary
111-
print_summary(results)
112-
113-
# Exit with appropriate status code
114-
all_passed = all(success for _, success in results)
115-
sys.exit(0 if all_passed else 1)
65+
if __name__ == "__main__":
66+
test_sea_session()

src/databricks/sql/backend/sea/models/base.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,29 @@ class ExternalLink:
4242
http_headers: Optional[Dict[str, str]] = None
4343

4444

45+
@dataclass
46+
class ChunkInfo:
47+
"""Information about a chunk in the result set."""
48+
49+
chunk_index: int
50+
byte_count: int
51+
row_offset: int
52+
row_count: int
53+
54+
4555
@dataclass
4656
class ResultData:
4757
"""Result data from a statement execution."""
4858

4959
data: Optional[List[List[Any]]] = None
5060
external_links: Optional[List[ExternalLink]] = None
61+
byte_count: Optional[int] = None
62+
chunk_index: Optional[int] = None
63+
next_chunk_index: Optional[int] = None
64+
next_chunk_internal_link: Optional[str] = None
65+
row_count: Optional[int] = None
66+
row_offset: Optional[int] = None
67+
attachment: Optional[bytes] = None
5168

5269

5370
@dataclass
@@ -73,5 +90,6 @@ class ResultManifest:
7390
total_byte_count: int
7491
total_chunk_count: int
7592
truncated: bool = False
76-
chunks: Optional[List[Dict[str, Any]]] = None
93+
chunks: Optional[List[ChunkInfo]] = None
7794
result_compression: Optional[str] = None
95+
is_volume_operation: Optional[bool] = None

src/databricks/sql/backend/sea/models/requests.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
@dataclass
1212
class StatementParameter:
13-
"""Parameter for a SQL statement."""
13+
"""Representation of a parameter for a SQL statement."""
1414

1515
name: str
1616
value: Optional[str] = None
@@ -19,7 +19,7 @@ class StatementParameter:
1919

2020
@dataclass
2121
class ExecuteStatementRequest:
22-
"""Request to execute a SQL statement."""
22+
"""Representation of a request to execute a SQL statement."""
2323

2424
session_id: str
2525
statement: str
@@ -65,7 +65,7 @@ def to_dict(self) -> Dict[str, Any]:
6565

6666
@dataclass
6767
class GetStatementRequest:
68-
"""Request to get information about a statement."""
68+
"""Representation of a request to get information about a statement."""
6969

7070
statement_id: str
7171

@@ -76,7 +76,7 @@ def to_dict(self) -> Dict[str, Any]:
7676

7777
@dataclass
7878
class CancelStatementRequest:
79-
"""Request to cancel a statement."""
79+
"""Representation of a request to cancel a statement."""
8080

8181
statement_id: str
8282

@@ -87,7 +87,7 @@ def to_dict(self) -> Dict[str, Any]:
8787

8888
@dataclass
8989
class CloseStatementRequest:
90-
"""Request to close a statement."""
90+
"""Representation of a request to close a statement."""
9191

9292
statement_id: str
9393

@@ -98,7 +98,7 @@ def to_dict(self) -> Dict[str, Any]:
9898

9999
@dataclass
100100
class CreateSessionRequest:
101-
"""Request to create a new session."""
101+
"""Representation of a request to create a new session."""
102102

103103
warehouse_id: str
104104
session_confs: Optional[Dict[str, str]] = None
@@ -123,7 +123,7 @@ def to_dict(self) -> Dict[str, Any]:
123123

124124
@dataclass
125125
class DeleteSessionRequest:
126-
"""Request to delete a session."""
126+
"""Representation of a request to delete a session."""
127127

128128
warehouse_id: str
129129
session_id: str

src/databricks/sql/backend/sea/models/responses.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ResultData,
1515
ServiceError,
1616
ExternalLink,
17+
ChunkInfo,
1718
)
1819

1920

@@ -43,15 +44,28 @@ def _parse_manifest(data: Dict[str, Any]) -> ResultManifest:
4344
"""Parse manifest from response data."""
4445

4546
manifest_data = data.get("manifest", {})
47+
chunks = None
48+
if "chunks" in manifest_data:
49+
chunks = [
50+
ChunkInfo(
51+
chunk_index=chunk.get("chunk_index", 0),
52+
byte_count=chunk.get("byte_count", 0),
53+
row_offset=chunk.get("row_offset", 0),
54+
row_count=chunk.get("row_count", 0),
55+
)
56+
for chunk in manifest_data.get("chunks", [])
57+
]
58+
4659
return ResultManifest(
4760
format=manifest_data.get("format", ""),
4861
schema=manifest_data.get("schema", {}),
4962
total_row_count=manifest_data.get("total_row_count", 0),
5063
total_byte_count=manifest_data.get("total_byte_count", 0),
5164
total_chunk_count=manifest_data.get("total_chunk_count", 0),
5265
truncated=manifest_data.get("truncated", False),
53-
chunks=manifest_data.get("chunks"),
66+
chunks=chunks,
5467
result_compression=manifest_data.get("result_compression"),
68+
is_volume_operation=manifest_data.get("is_volume_operation"),
5569
)
5670

5771

@@ -80,12 +94,19 @@ def _parse_result(data: Dict[str, Any]) -> ResultData:
8094
return ResultData(
8195
data=result_data.get("data_array"),
8296
external_links=external_links,
97+
byte_count=result_data.get("byte_count"),
98+
chunk_index=result_data.get("chunk_index"),
99+
next_chunk_index=result_data.get("next_chunk_index"),
100+
next_chunk_internal_link=result_data.get("next_chunk_internal_link"),
101+
row_count=result_data.get("row_count"),
102+
row_offset=result_data.get("row_offset"),
103+
attachment=result_data.get("attachment"),
83104
)
84105

85106

86107
@dataclass
87108
class ExecuteStatementResponse:
88-
"""Response from executing a SQL statement."""
109+
"""Representation of the response from executing a SQL statement."""
89110

90111
statement_id: str
91112
status: StatementStatus
@@ -105,7 +126,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ExecuteStatementResponse":
105126

106127
@dataclass
107128
class GetStatementResponse:
108-
"""Response from getting information about a statement."""
129+
"""Representation of the response from getting information about a statement."""
109130

110131
statement_id: str
111132
status: StatementStatus
@@ -125,7 +146,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "GetStatementResponse":
125146

126147
@dataclass
127148
class CreateSessionResponse:
128-
"""Response from creating a new session."""
149+
"""Representation of the response from creating a new session."""
129150

130151
session_id: str
131152

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import logging
33
import requests
4-
from typing import Callable, Dict, Any, Optional, Union, List, Tuple
4+
from typing import Callable, Dict, Any, Optional, List, Tuple
55
from urllib.parse import urljoin
66

77
from databricks.sql.auth.authenticators import AuthProvider

0 commit comments

Comments
 (0)