Skip to content

Commit 152f565

Browse files
fix type codes by using Thrift ttypes
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent c23d540 commit 152f565

File tree

5 files changed

+96
-62
lines changed

5 files changed

+96
-62
lines changed

src/databricks/sql/backend/sea/utils/conversion.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -50,60 +50,65 @@ def _convert_decimal(
5050

5151
class SqlType:
5252
"""
53-
SQL type constants
53+
SQL type constants based on Thrift TTypeId values.
5454
55-
The list of types can be found in the SEA REST API Reference:
56-
https://docs.databricks.com/api/workspace/statementexecution/executestatement
55+
These correspond to the normalized type names that come from the SEA backend
56+
after normalize_sea_type_to_thrift processing (lowercase, without _TYPE suffix).
5757
"""
5858

5959
# Numeric types
60-
BYTE = "byte"
61-
SHORT = "short"
62-
INT = "int"
63-
LONG = "long"
64-
FLOAT = "float"
65-
DOUBLE = "double"
66-
DECIMAL = "decimal"
60+
TINYINT = "tinyint" # Maps to TTypeId.TINYINT_TYPE
61+
SMALLINT = "smallint" # Maps to TTypeId.SMALLINT_TYPE
62+
INT = "int" # Maps to TTypeId.INT_TYPE
63+
BIGINT = "bigint" # Maps to TTypeId.BIGINT_TYPE
64+
FLOAT = "float" # Maps to TTypeId.FLOAT_TYPE
65+
DOUBLE = "double" # Maps to TTypeId.DOUBLE_TYPE
66+
DECIMAL = "decimal" # Maps to TTypeId.DECIMAL_TYPE
6767

6868
# Boolean type
69-
BOOLEAN = "boolean"
69+
BOOLEAN = "boolean" # Maps to TTypeId.BOOLEAN_TYPE
7070

7171
# Date/Time types
72-
DATE = "date"
73-
TIMESTAMP = "timestamp"
74-
INTERVAL = "interval"
72+
DATE = "date" # Maps to TTypeId.DATE_TYPE
73+
TIMESTAMP = "timestamp" # Maps to TTypeId.TIMESTAMP_TYPE
74+
INTERVAL_YEAR_MONTH = (
75+
"interval_year_month" # Maps to TTypeId.INTERVAL_YEAR_MONTH_TYPE
76+
)
77+
INTERVAL_DAY_TIME = "interval_day_time" # Maps to TTypeId.INTERVAL_DAY_TIME_TYPE
7578

7679
# String types
77-
CHAR = "char"
78-
STRING = "string"
80+
CHAR = "char" # Maps to TTypeId.CHAR_TYPE
81+
VARCHAR = "varchar" # Maps to TTypeId.VARCHAR_TYPE
82+
STRING = "string" # Maps to TTypeId.STRING_TYPE
7983

8084
# Binary type
81-
BINARY = "binary"
85+
BINARY = "binary" # Maps to TTypeId.BINARY_TYPE
8286

8387
# Complex types
84-
ARRAY = "array"
85-
MAP = "map"
86-
STRUCT = "struct"
88+
ARRAY = "array" # Maps to TTypeId.ARRAY_TYPE
89+
MAP = "map" # Maps to TTypeId.MAP_TYPE
90+
STRUCT = "struct" # Maps to TTypeId.STRUCT_TYPE
8791

8892
# Other types
89-
NULL = "null"
90-
USER_DEFINED_TYPE = "user_defined_type"
93+
NULL = "null" # Maps to TTypeId.NULL_TYPE
94+
UNION = "union" # Maps to TTypeId.UNION_TYPE
95+
USER_DEFINED = "user_defined" # Maps to TTypeId.USER_DEFINED_TYPE
9196

9297

9398
class SqlTypeConverter:
9499
"""
95100
Utility class for converting SQL types to Python types.
96-
Based on the types supported by the Databricks SDK.
101+
Based on the Thrift TTypeId types after normalization.
97102
"""
98103

99104
# SQL type to conversion function mapping
100105
# TODO: complex types
101106
TYPE_MAPPING: Dict[str, Callable] = {
102107
# Numeric types
103-
SqlType.BYTE: lambda v: int(v),
104-
SqlType.SHORT: lambda v: int(v),
108+
SqlType.TINYINT: lambda v: int(v),
109+
SqlType.SMALLINT: lambda v: int(v),
105110
SqlType.INT: lambda v: int(v),
106-
SqlType.LONG: lambda v: int(v),
111+
SqlType.BIGINT: lambda v: int(v),
107112
SqlType.FLOAT: lambda v: float(v),
108113
SqlType.DOUBLE: lambda v: float(v),
109114
SqlType.DECIMAL: _convert_decimal,
@@ -112,16 +117,18 @@ class SqlTypeConverter:
112117
# Date/Time types
113118
SqlType.DATE: lambda v: datetime.date.fromisoformat(v),
114119
SqlType.TIMESTAMP: lambda v: parser.parse(v),
115-
SqlType.INTERVAL: lambda v: v, # Keep as string for now
120+
SqlType.INTERVAL_YEAR_MONTH: lambda v: v, # Keep as string for now
121+
SqlType.INTERVAL_DAY_TIME: lambda v: v, # Keep as string for now
116122
# String types - no conversion needed
117123
SqlType.CHAR: lambda v: v,
124+
SqlType.VARCHAR: lambda v: v,
118125
SqlType.STRING: lambda v: v,
119126
# Binary type
120127
SqlType.BINARY: lambda v: bytes.fromhex(v),
121128
# Other types
122129
SqlType.NULL: lambda v: None,
123130
# User-defined types are returned as-is; complex types are not mapped yet and fall through unchanged
124-
SqlType.USER_DEFINED_TYPE: lambda v: v,
131+
SqlType.USER_DEFINED: lambda v: v,
125132
}
126133

127134
@staticmethod
@@ -135,7 +142,7 @@ def convert_value(
135142
136143
Args:
137144
value: The string value to convert
138-
sql_type: The SQL type (e.g., 'int', 'decimal')
145+
sql_type: The SQL type (e.g., 'tinyint', 'decimal')
139146
**kwargs: Additional keyword arguments for the conversion function
140147
141148
Returns:

tests/unit/test_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,7 @@ def test_negative_fetch_throws_exception(self):
265265
mock_backend = Mock()
266266
mock_backend.fetch_results.return_value = (Mock(), False, 0)
267267

268-
result_set = ThriftResultSet(
269-
Mock(), Mock(), mock_backend
270-
)
268+
result_set = ThriftResultSet(Mock(), Mock(), mock_backend)
271269

272270
with self.assertRaises(ValueError) as e:
273271
result_set.fetchmany(-1)

tests/unit/test_downloader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ class DownloaderTests(unittest.TestCase):
2626
def _setup_time_mock_for_download(self, mock_time, end_time):
2727
"""Helper to setup time mock that handles logging system calls."""
2828
call_count = [0]
29+
2930
def time_side_effect():
3031
call_count[0] += 1
3132
if call_count[0] <= 2: # First two calls (validation, start_time)
3233
return 1000
3334
else: # All subsequent calls (logging, duration calculation)
3435
return end_time
36+
3537
mock_time.side_effect = time_side_effect
3638

3739
@patch("time.time", return_value=1000)
@@ -104,7 +106,7 @@ def test_run_get_response_not_ok(self, mock_time):
104106
@patch("time.time")
105107
def test_run_uncompressed_successful(self, mock_time):
106108
self._setup_time_mock_for_download(mock_time, 1000.5)
107-
109+
108110
http_client = DatabricksHttpClient.get_instance()
109111
file_bytes = b"1234567890" * 10
110112
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
@@ -133,7 +135,7 @@ def test_run_uncompressed_successful(self, mock_time):
133135
@patch("time.time")
134136
def test_run_compressed_successful(self, mock_time):
135137
self._setup_time_mock_for_download(mock_time, 1000.2)
136-
138+
137139
http_client = DatabricksHttpClient.get_instance()
138140
file_bytes = b"1234567890" * 10
139141
compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'

tests/unit/test_sea_conversion.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ class TestSqlTypeConverter:
1818
def test_convert_numeric_types(self):
1919
"""Test converting numeric types."""
2020
# Test integer types
21-
assert SqlTypeConverter.convert_value("123", SqlType.BYTE) == 123
22-
assert SqlTypeConverter.convert_value("456", SqlType.SHORT) == 456
21+
assert SqlTypeConverter.convert_value("123", SqlType.TINYINT) == 123
22+
assert SqlTypeConverter.convert_value("456", SqlType.SMALLINT) == 456
2323
assert SqlTypeConverter.convert_value("789", SqlType.INT) == 789
24-
assert SqlTypeConverter.convert_value("1234567890", SqlType.LONG) == 1234567890
24+
assert (
25+
SqlTypeConverter.convert_value("1234567890", SqlType.BIGINT) == 1234567890
26+
)
2527

2628
# Test floating point types
2729
assert SqlTypeConverter.convert_value("123.45", SqlType.FLOAT) == 123.45
@@ -80,11 +82,16 @@ def test_convert_datetime_types(self):
8082
assert timestamp_value.minute == 30
8183
assert timestamp_value.second == 45
8284

83-
# Test interval type (currently returns as string)
84-
interval_value = SqlTypeConverter.convert_value(
85-
"1 day 2 hours", SqlType.INTERVAL
85+
# Test interval types (currently return as string)
86+
interval_ym_value = SqlTypeConverter.convert_value(
87+
"1-6", SqlType.INTERVAL_YEAR_MONTH
88+
)
89+
assert interval_ym_value == "1-6"
90+
91+
interval_dt_value = SqlTypeConverter.convert_value(
92+
"1 day 2 hours", SqlType.INTERVAL_DAY_TIME
8693
)
87-
assert interval_value == "1 day 2 hours"
94+
assert interval_dt_value == "1 day 2 hours"
8895

8996
# Test invalid date input
9097
result = SqlTypeConverter.convert_value("not_a_date", SqlType.DATE)
@@ -98,6 +105,10 @@ def test_convert_string_types(self):
98105
== "test string"
99106
)
100107
assert SqlTypeConverter.convert_value("test char", SqlType.CHAR) == "test char"
108+
assert (
109+
SqlTypeConverter.convert_value("test varchar", SqlType.VARCHAR)
110+
== "test varchar"
111+
)
101112

102113
def test_convert_binary_type(self):
103114
"""Test converting binary type."""
@@ -115,7 +126,7 @@ def test_convert_unsupported_type(self):
115126
# Should return the original value
116127
assert SqlTypeConverter.convert_value("test", "unsupported_type") == "test"
117128

118-
# Complex types should return as-is
129+
# Complex types should return as-is (not yet implemented in TYPE_MAPPING)
119130
assert (
120131
SqlTypeConverter.convert_value("complex_value", SqlType.ARRAY)
121132
== "complex_value"

tests/unit/test_telemetry_retry.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
77
from databricks.sql.auth.retry import DatabricksRetryPolicy
88

9-
PATCH_TARGET = 'urllib3.connectionpool.HTTPSConnectionPool._get_conn'
9+
PATCH_TARGET = "urllib3.connectionpool.HTTPSConnectionPool._get_conn"
10+
1011

1112
def create_mock_conn(responses):
1213
"""Creates a mock connection object whose getresponse() method yields a series of responses."""
@@ -16,15 +17,18 @@ def create_mock_conn(responses):
1617
mock_http_response = MagicMock()
1718
mock_http_response.status = resp.get("status")
1819
mock_http_response.headers = resp.get("headers", {})
19-
body = resp.get("body", b'{}')
20+
body = resp.get("body", b"{}")
2021
mock_http_response.fp = io.BytesIO(body)
22+
2123
def release():
2224
mock_http_response.fp.close()
25+
2326
mock_http_response.release_conn = release
2427
mock_http_responses.append(mock_http_response)
2528
mock_conn.getresponse.side_effect = mock_http_responses
2629
return mock_conn
2730

31+
2832
class TestTelemetryClientRetries:
2933
@pytest.fixture(autouse=True)
3034
def setup_and_teardown(self):
@@ -49,28 +53,28 @@ def get_client(self, session_id, num_retries=3):
4953
host_url="test.databricks.com",
5054
)
5155
client = TelemetryClientFactory.get_telemetry_client(session_id)
52-
56+
5357
retry_policy = DatabricksRetryPolicy(
5458
delay_min=0.01,
5559
delay_max=0.02,
5660
stop_after_attempts_duration=2.0,
57-
stop_after_attempts_count=num_retries,
61+
stop_after_attempts_count=num_retries,
5862
delay_default=0.1,
5963
force_dangerous_codes=[],
60-
urllib3_kwargs={'total': num_retries}
64+
urllib3_kwargs={"total": num_retries},
6165
)
6266
adapter = client._http_client.session.adapters.get("https://")
6367
adapter.max_retries = retry_policy
6468
return client
6569

6670
@pytest.mark.parametrize(
67-
"status_code, description",
68-
[
69-
(401, "Unauthorized"),
70-
(403, "Forbidden"),
71-
(501, "Not Implemented"),
72-
(200, "Success"),
73-
],
71+
"status_code, description",
72+
[
73+
(401, "Unauthorized"),
74+
(403, "Forbidden"),
75+
(501, "Not Implemented"),
76+
(200, "Success"),
77+
],
7478
)
7579
def test_non_retryable_status_codes_are_not_retried(self, status_code, description):
7680
"""
@@ -80,7 +84,9 @@ def test_non_retryable_status_codes_are_not_retried(self, status_code, descripti
8084
client = self.get_client(f"session-{status_code}")
8185
mock_responses = [{"status": status_code}]
8286

83-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
87+
with patch(
88+
PATCH_TARGET, return_value=create_mock_conn(mock_responses)
89+
) as mock_get_conn:
8490
client.export_failure_log("TestError", "Test message")
8591
TelemetryClientFactory.close(client._session_id_hex)
8692

@@ -92,16 +98,26 @@ def test_exceeds_retry_count_limit(self):
9298
Verifies that the client respects the Retry-After header and retries on 429, 502, 503.
9399
"""
94100
num_retries = 3
95-
expected_total_calls = num_retries + 1
101+
expected_total_calls = num_retries + 1
96102
retry_after = 1
97103
client = self.get_client("session-exceed-limit", num_retries=num_retries)
98-
mock_responses = [{"status": 503, "headers": {"Retry-After": str(retry_after)}}, {"status": 429}, {"status": 502}, {"status": 503}]
99-
100-
with patch(PATCH_TARGET, return_value=create_mock_conn(mock_responses)) as mock_get_conn:
104+
mock_responses = [
105+
{"status": 503, "headers": {"Retry-After": str(retry_after)}},
106+
{"status": 429},
107+
{"status": 502},
108+
{"status": 503},
109+
]
110+
111+
with patch(
112+
PATCH_TARGET, return_value=create_mock_conn(mock_responses)
113+
) as mock_get_conn:
101114
start_time = time.time()
102115
client.export_failure_log("TestError", "Test message")
103116
TelemetryClientFactory.close(client._session_id_hex)
104117
end_time = time.time()
105-
106-
assert mock_get_conn.return_value.getresponse.call_count == expected_total_calls
107-
assert end_time - start_time > retry_after
118+
119+
assert (
120+
mock_get_conn.return_value.getresponse.call_count
121+
== expected_total_calls
122+
)
123+
assert end_time - start_time > retry_after

0 commit comments

Comments
 (0)