Skip to content

Commit 92b778c

Browse files
cleaner, organised test scripts
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent ac381f3 commit 92b778c

File tree

8 files changed

+574
-160
lines changed

8 files changed

+574
-160
lines changed
Lines changed: 103 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -1,162 +1,111 @@
1+
"""
2+
Main script to run all SEA connector tests.
3+
4+
This script imports and runs all the individual test modules and displays
5+
a summary of test results with visual indicators.
6+
"""
17
import os
28
import sys
39
import logging
4-
from databricks.sql.client import Connection
10+
import importlib.util
11+
from typing import Dict, Callable, List, Tuple
512

6-
logging.basicConfig(level=logging.DEBUG)
13+
# Configure logging
14+
logging.basicConfig(level=logging.INFO)
715
logger = logging.getLogger(__name__)
816

9-
10-
def test_sea_query_exec():
11-
"""
12-
Test executing a query using the SEA backend with result compression.
13-
14-
This function connects to a Databricks SQL endpoint using the SEA backend,
15-
executes a simple query with result compression enabled and disabled,
16-
and verifies that execution completes successfully.
17-
"""
18-
server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
19-
http_path = os.environ.get("DATABRICKS_HTTP_PATH")
20-
access_token = os.environ.get("DATABRICKS_TOKEN")
21-
catalog = os.environ.get("DATABRICKS_CATALOG")
22-
23-
if not all([server_hostname, http_path, access_token]):
24-
logger.error("Missing required environment variables.")
25-
logger.error(
26-
"Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
27-
)
28-
sys.exit(1)
29-
30-
try:
31-
# Test with compression enabled
32-
logger.info("Creating connection with LZ4 compression enabled")
33-
connection = Connection(
34-
server_hostname=server_hostname,
35-
http_path=http_path,
36-
access_token=access_token,
37-
catalog=catalog,
38-
schema="default",
39-
use_sea=True,
40-
user_agent_entry="SEA-Test-Client",
41-
use_cloud_fetch=True, # Enable cloud fetch to use compression
42-
enable_query_result_lz4_compression=True, # Enable LZ4 compression
43-
)
44-
45-
logger.info(
46-
f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
47-
)
48-
logger.info(f"backend type: {type(connection.session.backend)}")
49-
50-
# Execute a simple query with compression enabled
51-
cursor = connection.cursor(arraysize=0, buffer_size_bytes=0)
52-
logger.info("Executing query with LZ4 compression: SELECT 1 as test_value")
53-
cursor.execute("SELECT 1 as test_value")
54-
logger.info("Query with compression executed successfully")
55-
cursor.close()
56-
connection.close()
57-
logger.info("Successfully closed SEA session with compression enabled")
58-
59-
# Test with compression disabled
60-
logger.info("Creating connection with LZ4 compression disabled")
61-
connection = Connection(
62-
server_hostname=server_hostname,
63-
http_path=http_path,
64-
access_token=access_token,
65-
catalog=catalog,
66-
schema="default",
67-
use_sea=True,
68-
user_agent_entry="SEA-Test-Client",
69-
use_cloud_fetch=False, # Enable cloud fetch
70-
enable_query_result_lz4_compression=False, # Disable LZ4 compression
71-
)
72-
73-
logger.info(
74-
f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
75-
)
76-
77-
# Execute a simple query with compression disabled
78-
cursor = connection.cursor(arraysize=0, buffer_size_bytes=0)
79-
logger.info("Executing query without compression: SELECT 1 as test_value")
80-
cursor.execute("SELECT 1 as test_value")
81-
logger.info("Query without compression executed successfully")
82-
cursor.close()
83-
connection.close()
84-
logger.info("Successfully closed SEA session with compression disabled")
85-
86-
except Exception as e:
87-
logger.error(f"Error during SEA query execution test: {str(e)}")
88-
import traceback
89-
90-
logger.error(traceback.format_exc())
91-
sys.exit(1)
92-
93-
logger.info("SEA query execution test with compression completed successfully")
94-
95-
96-
def test_sea_session():
97-
"""
98-
Test opening and closing a SEA session using the connector.
99-
100-
This function connects to a Databricks SQL endpoint using the SEA backend,
101-
opens a session, and then closes it.
102-
103-
Required environment variables:
104-
- DATABRICKS_SERVER_HOSTNAME: Databricks server hostname
105-
- DATABRICKS_HTTP_PATH: HTTP path for the SQL endpoint
106-
- DATABRICKS_TOKEN: Personal access token for authentication
107-
"""
108-
server_hostname = os.environ.get("DATABRICKS_SERVER_HOSTNAME")
109-
http_path = os.environ.get("DATABRICKS_HTTP_PATH")
110-
access_token = os.environ.get("DATABRICKS_TOKEN")
111-
catalog = os.environ.get("DATABRICKS_CATALOG")
112-
113-
if not all([server_hostname, http_path, access_token]):
114-
logger.error("Missing required environment variables.")
115-
logger.error(
116-
"Please set DATABRICKS_SERVER_HOSTNAME, DATABRICKS_HTTP_PATH, and DATABRICKS_TOKEN."
117-
)
118-
sys.exit(1)
119-
120-
logger.info(f"Connecting to {server_hostname}")
121-
logger.info(f"HTTP Path: {http_path}")
122-
if catalog:
123-
logger.info(f"Using catalog: {catalog}")
124-
125-
try:
126-
logger.info("Creating connection with SEA backend...")
127-
connection = Connection(
128-
server_hostname=server_hostname,
129-
http_path=http_path,
130-
access_token=access_token,
131-
catalog=catalog,
132-
schema="default",
133-
use_sea=True,
134-
user_agent_entry="SEA-Test-Client", # add custom user agent
135-
)
136-
137-
logger.info(
138-
f"Successfully opened SEA session with ID: {connection.get_session_id_hex()}"
139-
)
140-
logger.info(f"backend type: {type(connection.session.backend)}")
141-
142-
# Close the connection
143-
logger.info("Closing the SEA session...")
144-
connection.close()
145-
logger.info("Successfully closed SEA session")
146-
147-
except Exception as e:
148-
logger.error(f"Error testing SEA session: {str(e)}")
149-
import traceback
150-
151-
logger.error(traceback.format_exc())
152-
sys.exit(1)
153-
154-
logger.info("SEA session test completed successfully")
155-
17+
# Define test modules and their main test functions
18+
TEST_MODULES = [
19+
"test_sea_session",
20+
"test_sea_sync_query",
21+
"test_sea_async_query",
22+
"test_sea_metadata",
23+
]
24+
25+
def load_test_function(module_name: str) -> Callable:
26+
"""Load a test function from a module."""
27+
module_path = os.path.join(
28+
os.path.dirname(os.path.abspath(__file__)),
29+
"tests",
30+
f"{module_name}.py"
31+
)
32+
33+
spec = importlib.util.spec_from_file_location(module_name, module_path)
34+
module = importlib.util.module_from_spec(spec)
35+
spec.loader.exec_module(module)
36+
37+
# Get the main test function (assuming it starts with "test_")
38+
for name in dir(module):
39+
if name.startswith("test_") and callable(getattr(module, name)):
40+
# For sync and async query modules, we want the main function that runs both tests
41+
if name == f"test_sea_{module_name.replace('test_sea_', '')}_exec":
42+
return getattr(module, name)
43+
44+
# Fallback to the first test function found
45+
for name in dir(module):
46+
if name.startswith("test_") and callable(getattr(module, name)):
47+
return getattr(module, name)
48+
49+
raise ValueError(f"No test function found in module {module_name}")
50+
51+
def run_tests() -> List[Tuple[str, bool]]:
52+
"""Run all tests and return results."""
53+
results = []
54+
55+
for module_name in TEST_MODULES:
56+
try:
57+
test_func = load_test_function(module_name)
58+
logger.info(f"\n{'=' * 50}")
59+
logger.info(f"Running test: {module_name}")
60+
logger.info(f"{'-' * 50}")
61+
62+
success = test_func()
63+
results.append((module_name, success))
64+
65+
status = "✅ PASSED" if success else "❌ FAILED"
66+
logger.info(f"Test {module_name}: {status}")
67+
68+
except Exception as e:
69+
logger.error(f"Error loading or running test {module_name}: {str(e)}")
70+
import traceback
71+
logger.error(traceback.format_exc())
72+
results.append((module_name, False))
73+
74+
return results
75+
76+
def print_summary(results: List[Tuple[str, bool]]) -> None:
77+
"""Print a summary of test results."""
78+
logger.info(f"\n{'=' * 50}")
79+
logger.info("TEST SUMMARY")
80+
logger.info(f"{'-' * 50}")
81+
82+
passed = sum(1 for _, success in results if success)
83+
total = len(results)
84+
85+
for module_name, success in results:
86+
status = "✅ PASSED" if success else "❌ FAILED"
87+
logger.info(f"{status} - {module_name}")
88+
89+
logger.info(f"{'-' * 50}")
90+
logger.info(f"Total: {total} | Passed: {passed} | Failed: {total - passed}")
91+
logger.info(f"{'=' * 50}")
15692

15793
if __name__ == "__main__":
158-
# Test session management
159-
test_sea_session()
160-
161-
# Test query execution with compression
162-
test_sea_query_exec()
94+
# Check if required environment variables are set
95+
required_vars = ["DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_TOKEN"]
96+
missing_vars = [var for var in required_vars if not os.environ.get(var)]
97+
98+
if missing_vars:
99+
logger.error(f"Missing required environment variables: {', '.join(missing_vars)}")
100+
logger.error("Please set these variables before running the tests.")
101+
sys.exit(1)
102+
103+
# Run all tests
104+
results = run_tests()
105+
106+
# Print summary
107+
print_summary(results)
108+
109+
# Exit with appropriate status code
110+
all_passed = all(success for _, success in results)
111+
sys.exit(0 if all_passed else 1)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# This file makes the tests directory a Python package

0 commit comments

Comments
 (0)