Skip to content

Commit 7789866

Browse files
committed
fix: linting and tests
1 parent 6d9399d commit 7789866

File tree

23 files changed

+1293
-452
lines changed

23 files changed

+1293
-452
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.13.0"
20+
rev: "v0.13.1"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

sqlspec/adapters/adbc/data_dictionary.py

Lines changed: 88 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from sqlspec.utils.logging import get_logger
88

99
if TYPE_CHECKING:
10+
from collections.abc import Callable
11+
1012
from sqlspec.adapters.adbc.driver import AdbcDriver
1113

1214
logger = get_logger("adapters.adbc.data_dictionary")
@@ -22,30 +24,19 @@
2224
class AdbcDataDictionary(SyncDataDictionaryBase):
2325
"""ADBC multi-dialect data dictionary.
2426
25-
Detects the underlying database dialect and delegates to appropriate logic.
27+
Delegates to appropriate dialect-specific logic based on the driver's dialect.
2628
"""
2729

28-
def __init__(self) -> None:
29-
"""Initialize ADBC data dictionary."""
30-
self._detected_dialect: Optional[str] = None
31-
self._cached_version: Optional[VersionInfo] = None
32-
33-
def _detect_dialect(self, driver: SyncDriverAdapterBase) -> str:
34-
"""Detect the underlying database dialect.
30+
def _get_dialect(self, driver: SyncDriverAdapterBase) -> str:
31+
"""Get dialect from ADBC driver.
3532
3633
Args:
3734
driver: ADBC driver instance
3835
3936
Returns:
40-
Detected dialect name
37+
Dialect name
4138
"""
42-
if self._detected_dialect:
43-
return self._detected_dialect
44-
45-
self._detected_dialect = (
46-
str(cast("AdbcDriver", driver).dialect) if cast("AdbcDriver", driver).dialect else "sqlite"
47-
)
48-
return self._detected_dialect
39+
return str(cast("AdbcDriver", driver).dialect)
4940

5041
def get_version(self, driver: SyncDriverAdapterBase) -> "Optional[VersionInfo]":
5142
"""Get database version information based on detected dialect.
@@ -56,59 +47,51 @@ def get_version(self, driver: SyncDriverAdapterBase) -> "Optional[VersionInfo]":
5647
Returns:
5748
Database version information or None if detection fails
5849
"""
59-
if self._cached_version:
60-
return self._cached_version
61-
62-
dialect = self._detect_dialect(driver)
63-
64-
if dialect == "postgres":
65-
try:
66-
version_str = cast("AdbcDriver", driver).select_value("SELECT version()")
67-
version_match = POSTGRES_VERSION_PATTERN.search(str(version_str))
68-
if version_match:
69-
major = int(version_match.group(1))
70-
minor = int(version_match.group(2))
71-
patch = int(version_match.group(3)) if version_match.group(3) else 0
72-
self._cached_version = VersionInfo(major, minor, patch)
73-
except Exception:
74-
logger.warning("Failed to get PostgreSQL version")
75-
76-
elif dialect == "sqlite":
77-
try:
78-
version_str = cast("AdbcDriver", driver).select_value("SELECT sqlite_version()")
79-
version_match = SQLITE_VERSION_PATTERN.match(str(version_str))
80-
if version_match:
81-
major, minor, patch = map(int, version_match.groups())
82-
self._cached_version = VersionInfo(major, minor, patch)
83-
except Exception:
84-
logger.warning("Failed to get SQLite version")
85-
86-
elif dialect == "duckdb":
87-
try:
88-
version_str = cast("AdbcDriver", driver).select_value("SELECT version()")
89-
version_match = DUCKDB_VERSION_PATTERN.search(str(version_str))
90-
if version_match:
91-
major, minor, patch = map(int, version_match.groups())
92-
self._cached_version = VersionInfo(major, minor, patch)
93-
except Exception:
94-
logger.warning("Failed to get DuckDB version")
95-
96-
elif dialect == "mysql":
97-
try:
98-
version_str = cast("AdbcDriver", driver).select_value("SELECT VERSION()")
99-
version_match = MYSQL_VERSION_PATTERN.search(str(version_str))
100-
if version_match:
101-
major, minor, patch = map(int, version_match.groups())
102-
self._cached_version = VersionInfo(major, minor, patch)
103-
except Exception:
104-
logger.warning("Failed to get MySQL version")
105-
106-
elif dialect == "bigquery":
107-
# BigQuery is a cloud service
108-
self._cached_version = VersionInfo(1, 0, 0)
109-
110-
logger.debug("Detected %s version: %s", dialect, self._cached_version)
111-
return self._cached_version
50+
dialect = self._get_dialect(driver)
51+
adbc_driver = cast("AdbcDriver", driver)
52+
53+
try:
54+
if dialect == "postgres":
55+
version_str = adbc_driver.select_value("SELECT version()")
56+
if version_str:
57+
match = POSTGRES_VERSION_PATTERN.search(str(version_str))
58+
if match:
59+
major = int(match.group(1))
60+
minor = int(match.group(2))
61+
patch = int(match.group(3)) if match.group(3) else 0
62+
return VersionInfo(major, minor, patch)
63+
64+
elif dialect == "sqlite":
65+
version_str = adbc_driver.select_value("SELECT sqlite_version()")
66+
if version_str:
67+
match = SQLITE_VERSION_PATTERN.match(str(version_str))
68+
if match:
69+
major, minor, patch = map(int, match.groups())
70+
return VersionInfo(major, minor, patch)
71+
72+
elif dialect == "duckdb":
73+
version_str = adbc_driver.select_value("SELECT version()")
74+
if version_str:
75+
match = DUCKDB_VERSION_PATTERN.search(str(version_str))
76+
if match:
77+
major, minor, patch = map(int, match.groups())
78+
return VersionInfo(major, minor, patch)
79+
80+
elif dialect == "mysql":
81+
version_str = adbc_driver.select_value("SELECT VERSION()")
82+
if version_str:
83+
match = MYSQL_VERSION_PATTERN.search(str(version_str))
84+
if match:
85+
major, minor, patch = map(int, match.groups())
86+
return VersionInfo(major, minor, patch)
87+
88+
elif dialect == "bigquery":
89+
return VersionInfo(1, 0, 0)
90+
91+
except Exception:
92+
logger.warning("Failed to get %s version", dialect)
93+
94+
return None
11295

11396
def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
11497
"""Check if database supports a specific feature based on detected dialect.
@@ -120,25 +103,35 @@ def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
120103
Returns:
121104
True if feature is supported, False otherwise
122105
"""
123-
dialect = self._detect_dialect(driver)
106+
dialect = self._get_dialect(driver)
124107
version_info = self.get_version(driver)
125108

126109
if dialect == "postgres":
127-
feature_checks = {
110+
feature_checks: dict[str, Callable[..., bool]] = {
128111
"supports_json": lambda v: v and v >= VersionInfo(9, 2, 0),
129112
"supports_jsonb": lambda v: v and v >= VersionInfo(9, 4, 0),
130113
"supports_uuid": lambda _: True,
131114
"supports_arrays": lambda _: True,
132115
"supports_returning": lambda v: v and v >= VersionInfo(8, 2, 0),
133116
"supports_upsert": lambda v: v and v >= VersionInfo(9, 5, 0),
117+
"supports_window_functions": lambda v: v and v >= VersionInfo(8, 4, 0),
118+
"supports_cte": lambda v: v and v >= VersionInfo(8, 4, 0),
119+
"supports_transactions": lambda _: True,
120+
"supports_prepared_statements": lambda _: True,
121+
"supports_schemas": lambda _: True,
134122
}
135123
elif dialect == "sqlite":
136124
feature_checks = {
137125
"supports_json": lambda v: v and v >= VersionInfo(3, 38, 0),
138126
"supports_returning": lambda v: v and v >= VersionInfo(3, 35, 0),
139127
"supports_upsert": lambda v: v and v >= VersionInfo(3, 24, 0),
140-
"supports_uuid": lambda _: False,
128+
"supports_window_functions": lambda v: v and v >= VersionInfo(3, 25, 0),
129+
"supports_cte": lambda v: v and v >= VersionInfo(3, 8, 3),
130+
"supports_transactions": lambda _: True,
131+
"supports_prepared_statements": lambda _: True,
132+
"supports_schemas": lambda _: False,
141133
"supports_arrays": lambda _: False,
134+
"supports_uuid": lambda _: False,
142135
}
143136
elif dialect == "duckdb":
144137
feature_checks = {
@@ -147,13 +140,22 @@ def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
147140
"supports_uuid": lambda _: True,
148141
"supports_returning": lambda v: v and v >= VersionInfo(0, 8, 0),
149142
"supports_upsert": lambda v: v and v >= VersionInfo(0, 8, 0),
143+
"supports_window_functions": lambda _: True,
144+
"supports_cte": lambda _: True,
145+
"supports_transactions": lambda _: True,
146+
"supports_prepared_statements": lambda _: True,
147+
"supports_schemas": lambda _: True,
150148
}
151149
elif dialect == "mysql":
152150
feature_checks = {
153151
"supports_json": lambda v: v and v >= VersionInfo(5, 7, 8),
154152
"supports_cte": lambda v: v and v >= VersionInfo(8, 0, 1),
155153
"supports_returning": lambda _: False,
156154
"supports_upsert": lambda _: True,
155+
"supports_window_functions": lambda v: v and v >= VersionInfo(8, 0, 2),
156+
"supports_transactions": lambda _: True,
157+
"supports_prepared_statements": lambda _: True,
158+
"supports_schemas": lambda _: True,
157159
"supports_uuid": lambda _: False,
158160
"supports_arrays": lambda _: False,
159161
}
@@ -164,23 +166,23 @@ def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
164166
"supports_structs": lambda _: True,
165167
"supports_returning": lambda _: False,
166168
"supports_upsert": lambda _: True,
169+
"supports_window_functions": lambda _: True,
170+
"supports_cte": lambda _: True,
171+
"supports_transactions": lambda _: False,
172+
"supports_prepared_statements": lambda _: True,
173+
"supports_schemas": lambda _: True,
167174
"supports_uuid": lambda _: False,
168175
}
169176
else:
170-
feature_checks = {}
171-
172-
# Common features
173-
common_features = {
174-
"supports_transactions": lambda _: True,
175-
"supports_prepared_statements": lambda _: True,
176-
"supports_window_functions": lambda _: True,
177-
"supports_cte": lambda _: True,
178-
}
179-
180-
feature_checks.update(common_features)
177+
feature_checks = {
178+
"supports_transactions": lambda _: True,
179+
"supports_prepared_statements": lambda _: True,
180+
"supports_window_functions": lambda _: True,
181+
"supports_cte": lambda _: True,
182+
}
181183

182184
if feature in feature_checks:
183-
return feature_checks[feature](version_info)
185+
return bool(feature_checks[feature](version_info))
184186

185187
return False
186188

@@ -194,7 +196,7 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) ->
194196
Returns:
195197
Database-specific type name
196198
"""
197-
dialect = self._detect_dialect(driver)
199+
dialect = self._get_dialect(driver)
198200
version_info = self.get_version(driver)
199201

200202
if dialect == "postgres":
@@ -210,6 +212,7 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) ->
210212
"timestamp": "TIMESTAMP WITH TIME ZONE",
211213
"text": "TEXT",
212214
"blob": "BYTEA",
215+
"array": "ARRAY",
213216
}
214217

215218
elif dialect == "sqlite":
@@ -254,7 +257,6 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) ->
254257
"array": "ARRAY",
255258
}
256259
else:
257-
# Generic fallback
258260
type_map = {
259261
"json": "TEXT",
260262
"uuid": "VARCHAR(36)",

sqlspec/adapters/aiosqlite/data_dictionary.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""SQLite-specific data dictionary for metadata queries via aiosqlite."""
22

33
import re
4-
from typing import TYPE_CHECKING, Optional, cast
4+
from typing import TYPE_CHECKING, Callable, Optional, cast
55

66
from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo
77
from sqlspec.utils.logging import get_logger
@@ -59,7 +59,7 @@ async def get_feature_flag(self, driver: AsyncDriverAdapterBase, feature: str) -
5959
if not version_info:
6060
return False
6161

62-
feature_checks = {
62+
feature_checks: dict[str, Callable[..., bool]] = {
6363
"supports_json": lambda v: v >= VersionInfo(3, 38, 0),
6464
"supports_returning": lambda v: v >= VersionInfo(3, 35, 0),
6565
"supports_upsert": lambda v: v >= VersionInfo(3, 24, 0),
@@ -73,7 +73,7 @@ async def get_feature_flag(self, driver: AsyncDriverAdapterBase, feature: str) -
7373
}
7474

7575
if feature in feature_checks:
76-
return feature_checks[feature](version_info)
76+
return bool(feature_checks[feature](version_info))
7777

7878
return False
7979

sqlspec/adapters/asyncmy/data_dictionary.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""MySQL-specific data dictionary for metadata queries via asyncmy."""
22

33
import re
4-
from typing import TYPE_CHECKING, Optional, cast
4+
from typing import TYPE_CHECKING, Callable, Optional, cast
55

66
from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo
77
from sqlspec.utils.logging import get_logger
@@ -58,7 +58,7 @@ async def get_feature_flag(self, driver: AsyncDriverAdapterBase, feature: str) -
5858
if not version_info:
5959
return False
6060

61-
feature_checks = {
61+
feature_checks: dict[str, Callable[..., bool]] = {
6262
"supports_json": lambda v: v >= VersionInfo(5, 7, 8),
6363
"supports_cte": lambda v: v >= VersionInfo(8, 0, 1),
6464
"supports_window_functions": lambda v: v >= VersionInfo(8, 0, 2),
@@ -72,7 +72,7 @@ async def get_feature_flag(self, driver: AsyncDriverAdapterBase, feature: str) -
7272
}
7373

7474
if feature in feature_checks:
75-
return feature_checks[feature](version_info)
75+
return bool(feature_checks[feature](version_info))
7676

7777
return False
7878

sqlspec/adapters/asyncpg/data_dictionary.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""PostgreSQL-specific data dictionary for metadata queries via asyncpg."""
22

33
import re
4-
from typing import TYPE_CHECKING, Optional, cast
4+
from typing import TYPE_CHECKING, Callable, Optional, cast
55

66
from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo
77
from sqlspec.utils.logging import get_logger
@@ -63,7 +63,7 @@ async def get_feature_flag(self, driver: AsyncDriverAdapterBase, feature: str) -
6363
if not version_info:
6464
return False
6565

66-
feature_checks = {
66+
feature_checks: dict[str, Callable[..., bool]] = {
6767
"supports_json": lambda v: v >= VersionInfo(9, 2, 0),
6868
"supports_jsonb": lambda v: v >= VersionInfo(9, 4, 0),
6969
"supports_uuid": lambda _: True, # UUID extension widely available
@@ -79,7 +79,7 @@ async def get_feature_flag(self, driver: AsyncDriverAdapterBase, feature: str) -
7979
}
8080

8181
if feature in feature_checks:
82-
return feature_checks[feature](version_info)
82+
return bool(feature_checks[feature](version_info))
8383

8484
return False
8585

sqlspec/adapters/duckdb/_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import TYPE_CHECKING
22

3-
from duckdb import DuckDBPyConnection
3+
from duckdb import DuckDBPyConnection # type: ignore[import-untyped]
44

55
if TYPE_CHECKING:
66
from typing_extensions import TypeAlias

sqlspec/adapters/duckdb/data_dictionary.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""DuckDB-specific data dictionary for metadata queries."""
22

33
import re
4-
from typing import TYPE_CHECKING, Optional, cast
4+
from typing import TYPE_CHECKING, Callable, Optional, cast
55

66
from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo
77
from sqlspec.utils.logging import get_logger
@@ -59,7 +59,7 @@ def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
5959
if not version_info:
6060
return False
6161

62-
feature_checks = {
62+
feature_checks: dict[str, Callable[..., bool]] = {
6363
"supports_json": lambda _: True, # DuckDB has excellent JSON support
6464
"supports_arrays": lambda _: True, # LIST type
6565
"supports_maps": lambda _: True, # MAP type
@@ -75,7 +75,7 @@ def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
7575
}
7676

7777
if feature in feature_checks:
78-
return feature_checks[feature](version_info)
78+
return bool(feature_checks[feature](version_info))
7979

8080
return False
8181

0 commit comments

Comments
 (0)