Skip to content

Commit 7a75762

Browse files
authored
feat: add data dictionary support for database metadata (#82)
Comprehensive data dictionary implementation for all adapters: - VersionInfo class for database version tracking and comparison - DataDictionaryMixin providing common functionality - AsyncDataDictionaryBase and SyncDataDictionaryBase abstract classes - Driver integration with data_dictionary property Adapter-specific implementations for all 10 database adapters: ADBC, AIOSQLite, AsyncMy, AsyncPG, BigQuery, DuckDB, OracleDB, psqlpy, Psycopg, SQLite Features: - Version detection and parsing - Feature flag checking (JSON support, RETURNING, CTEs, etc.) - Optimal type mapping for different type categories - Database introspection capabilities
1 parent f597baf commit 7a75762

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+3803
-385
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.13.0"
20+
rev: "v0.13.1"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

sqlspec/adapters/adbc/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,11 @@ def _get_connect_func(self) -> Callable[..., AdbcConnection]:
174174
try:
175175
connect_func = import_string(driver_path)
176176
except ImportError as e:
177-
driver_path_with_suffix = f"{driver_path}.dbapi.connect"
177+
# Only add .dbapi.connect if it's not already there
178+
if not driver_path.endswith(".dbapi.connect"):
179+
driver_path_with_suffix = f"{driver_path}.dbapi.connect"
180+
else:
181+
driver_path_with_suffix = driver_path
178182
try:
179183
connect_func = import_string(driver_path_with_suffix)
180184
except ImportError as e2:
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
"""ADBC multi-dialect data dictionary for metadata queries."""
2+
3+
import re
4+
from typing import TYPE_CHECKING, Optional, cast
5+
6+
from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo
7+
from sqlspec.utils.logging import get_logger
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import Callable
11+
12+
from sqlspec.adapters.adbc.driver import AdbcDriver
13+
14+
logger = get_logger("adapters.adbc.data_dictionary")
15+
16+
POSTGRES_VERSION_PATTERN = re.compile(r"PostgreSQL (\d+)\.(\d+)(?:\.(\d+))?")
17+
SQLITE_VERSION_PATTERN = re.compile(r"(\d+)\.(\d+)\.(\d+)")
18+
DUCKDB_VERSION_PATTERN = re.compile(r"v?(\d+)\.(\d+)\.(\d+)")
19+
MYSQL_VERSION_PATTERN = re.compile(r"(\d+)\.(\d+)\.(\d+)")
20+
21+
__all__ = ("AdbcDataDictionary",)
22+
23+
24+
class AdbcDataDictionary(SyncDataDictionaryBase):
25+
"""ADBC multi-dialect data dictionary.
26+
27+
Delegates to appropriate dialect-specific logic based on the driver's dialect.
28+
"""
29+
30+
def _get_dialect(self, driver: SyncDriverAdapterBase) -> str:
31+
"""Get dialect from ADBC driver.
32+
33+
Args:
34+
driver: ADBC driver instance
35+
36+
Returns:
37+
Dialect name
38+
"""
39+
return str(cast("AdbcDriver", driver).dialect)
40+
41+
def get_version(self, driver: SyncDriverAdapterBase) -> "Optional[VersionInfo]":
42+
"""Get database version information based on detected dialect.
43+
44+
Args:
45+
driver: ADBC driver instance
46+
47+
Returns:
48+
Database version information or None if detection fails
49+
"""
50+
dialect = self._get_dialect(driver)
51+
adbc_driver = cast("AdbcDriver", driver)
52+
53+
try:
54+
if dialect == "postgres":
55+
version_str = adbc_driver.select_value("SELECT version()")
56+
if version_str:
57+
match = POSTGRES_VERSION_PATTERN.search(str(version_str))
58+
if match:
59+
major = int(match.group(1))
60+
minor = int(match.group(2))
61+
patch = int(match.group(3)) if match.group(3) else 0
62+
return VersionInfo(major, minor, patch)
63+
64+
elif dialect == "sqlite":
65+
version_str = adbc_driver.select_value("SELECT sqlite_version()")
66+
if version_str:
67+
match = SQLITE_VERSION_PATTERN.match(str(version_str))
68+
if match:
69+
major, minor, patch = map(int, match.groups())
70+
return VersionInfo(major, minor, patch)
71+
72+
elif dialect == "duckdb":
73+
version_str = adbc_driver.select_value("SELECT version()")
74+
if version_str:
75+
match = DUCKDB_VERSION_PATTERN.search(str(version_str))
76+
if match:
77+
major, minor, patch = map(int, match.groups())
78+
return VersionInfo(major, minor, patch)
79+
80+
elif dialect == "mysql":
81+
version_str = adbc_driver.select_value("SELECT VERSION()")
82+
if version_str:
83+
match = MYSQL_VERSION_PATTERN.search(str(version_str))
84+
if match:
85+
major, minor, patch = map(int, match.groups())
86+
return VersionInfo(major, minor, patch)
87+
88+
elif dialect == "bigquery":
89+
return VersionInfo(1, 0, 0)
90+
91+
except Exception:
92+
logger.warning("Failed to get %s version", dialect)
93+
94+
return None
95+
96+
def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
97+
"""Check if database supports a specific feature based on detected dialect.
98+
99+
Args:
100+
driver: ADBC driver instance
101+
feature: Feature name to check
102+
103+
Returns:
104+
True if feature is supported, False otherwise
105+
"""
106+
dialect = self._get_dialect(driver)
107+
version_info = self.get_version(driver)
108+
109+
if dialect == "postgres":
110+
feature_checks: dict[str, Callable[..., bool]] = {
111+
"supports_json": lambda v: v and v >= VersionInfo(9, 2, 0),
112+
"supports_jsonb": lambda v: v and v >= VersionInfo(9, 4, 0),
113+
"supports_uuid": lambda _: True,
114+
"supports_arrays": lambda _: True,
115+
"supports_returning": lambda v: v and v >= VersionInfo(8, 2, 0),
116+
"supports_upsert": lambda v: v and v >= VersionInfo(9, 5, 0),
117+
"supports_window_functions": lambda v: v and v >= VersionInfo(8, 4, 0),
118+
"supports_cte": lambda v: v and v >= VersionInfo(8, 4, 0),
119+
"supports_transactions": lambda _: True,
120+
"supports_prepared_statements": lambda _: True,
121+
"supports_schemas": lambda _: True,
122+
}
123+
elif dialect == "sqlite":
124+
feature_checks = {
125+
"supports_json": lambda v: v and v >= VersionInfo(3, 38, 0),
126+
"supports_returning": lambda v: v and v >= VersionInfo(3, 35, 0),
127+
"supports_upsert": lambda v: v and v >= VersionInfo(3, 24, 0),
128+
"supports_window_functions": lambda v: v and v >= VersionInfo(3, 25, 0),
129+
"supports_cte": lambda v: v and v >= VersionInfo(3, 8, 3),
130+
"supports_transactions": lambda _: True,
131+
"supports_prepared_statements": lambda _: True,
132+
"supports_schemas": lambda _: False,
133+
"supports_arrays": lambda _: False,
134+
"supports_uuid": lambda _: False,
135+
}
136+
elif dialect == "duckdb":
137+
feature_checks = {
138+
"supports_json": lambda _: True,
139+
"supports_arrays": lambda _: True,
140+
"supports_uuid": lambda _: True,
141+
"supports_returning": lambda v: v and v >= VersionInfo(0, 8, 0),
142+
"supports_upsert": lambda v: v and v >= VersionInfo(0, 8, 0),
143+
"supports_window_functions": lambda _: True,
144+
"supports_cte": lambda _: True,
145+
"supports_transactions": lambda _: True,
146+
"supports_prepared_statements": lambda _: True,
147+
"supports_schemas": lambda _: True,
148+
}
149+
elif dialect == "mysql":
150+
feature_checks = {
151+
"supports_json": lambda v: v and v >= VersionInfo(5, 7, 8),
152+
"supports_cte": lambda v: v and v >= VersionInfo(8, 0, 1),
153+
"supports_returning": lambda _: False,
154+
"supports_upsert": lambda _: True,
155+
"supports_window_functions": lambda v: v and v >= VersionInfo(8, 0, 2),
156+
"supports_transactions": lambda _: True,
157+
"supports_prepared_statements": lambda _: True,
158+
"supports_schemas": lambda _: True,
159+
"supports_uuid": lambda _: False,
160+
"supports_arrays": lambda _: False,
161+
}
162+
elif dialect == "bigquery":
163+
feature_checks = {
164+
"supports_json": lambda _: True,
165+
"supports_arrays": lambda _: True,
166+
"supports_structs": lambda _: True,
167+
"supports_returning": lambda _: False,
168+
"supports_upsert": lambda _: True,
169+
"supports_window_functions": lambda _: True,
170+
"supports_cte": lambda _: True,
171+
"supports_transactions": lambda _: False,
172+
"supports_prepared_statements": lambda _: True,
173+
"supports_schemas": lambda _: True,
174+
"supports_uuid": lambda _: False,
175+
}
176+
else:
177+
feature_checks = {
178+
"supports_transactions": lambda _: True,
179+
"supports_prepared_statements": lambda _: True,
180+
"supports_window_functions": lambda _: True,
181+
"supports_cte": lambda _: True,
182+
}
183+
184+
if feature in feature_checks:
185+
return bool(feature_checks[feature](version_info))
186+
187+
return False
188+
189+
def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) -> str:
190+
"""Get optimal database type for a category based on detected dialect.
191+
192+
Args:
193+
driver: ADBC driver instance
194+
type_category: Type category
195+
196+
Returns:
197+
Database-specific type name
198+
"""
199+
dialect = self._get_dialect(driver)
200+
version_info = self.get_version(driver)
201+
202+
if dialect == "postgres":
203+
if type_category == "json":
204+
if version_info and version_info >= VersionInfo(9, 4, 0):
205+
return "JSONB"
206+
if version_info and version_info >= VersionInfo(9, 2, 0):
207+
return "JSON"
208+
return "TEXT"
209+
type_map = {
210+
"uuid": "UUID",
211+
"boolean": "BOOLEAN",
212+
"timestamp": "TIMESTAMP WITH TIME ZONE",
213+
"text": "TEXT",
214+
"blob": "BYTEA",
215+
"array": "ARRAY",
216+
}
217+
218+
elif dialect == "sqlite":
219+
if type_category == "json":
220+
if version_info and version_info >= VersionInfo(3, 38, 0):
221+
return "JSON"
222+
return "TEXT"
223+
type_map = {"uuid": "TEXT", "boolean": "INTEGER", "timestamp": "TIMESTAMP", "text": "TEXT", "blob": "BLOB"}
224+
225+
elif dialect == "duckdb":
226+
type_map = {
227+
"json": "JSON",
228+
"uuid": "UUID",
229+
"boolean": "BOOLEAN",
230+
"timestamp": "TIMESTAMP",
231+
"text": "TEXT",
232+
"blob": "BLOB",
233+
"array": "LIST",
234+
}
235+
236+
elif dialect == "mysql":
237+
if type_category == "json":
238+
if version_info and version_info >= VersionInfo(5, 7, 8):
239+
return "JSON"
240+
return "TEXT"
241+
type_map = {
242+
"uuid": "VARCHAR(36)",
243+
"boolean": "TINYINT(1)",
244+
"timestamp": "TIMESTAMP",
245+
"text": "TEXT",
246+
"blob": "BLOB",
247+
}
248+
249+
elif dialect == "bigquery":
250+
type_map = {
251+
"json": "JSON",
252+
"uuid": "STRING",
253+
"boolean": "BOOL",
254+
"timestamp": "TIMESTAMP",
255+
"text": "STRING",
256+
"blob": "BYTES",
257+
"array": "ARRAY",
258+
}
259+
else:
260+
type_map = {
261+
"json": "TEXT",
262+
"uuid": "VARCHAR(36)",
263+
"boolean": "INTEGER",
264+
"timestamp": "TIMESTAMP",
265+
"text": "TEXT",
266+
"blob": "BLOB",
267+
}
268+
269+
return type_map.get(type_category, "TEXT")
270+
271+
def list_available_features(self) -> "list[str]":
272+
"""List available feature flags across all supported dialects.
273+
274+
Returns:
275+
List of supported feature names
276+
"""
277+
return [
278+
"supports_json",
279+
"supports_jsonb",
280+
"supports_uuid",
281+
"supports_arrays",
282+
"supports_structs",
283+
"supports_returning",
284+
"supports_upsert",
285+
"supports_window_functions",
286+
"supports_cte",
287+
"supports_transactions",
288+
"supports_prepared_statements",
289+
"supports_schemas",
290+
]

sqlspec/adapters/adbc/driver.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sqlspec.adapters.adbc._types import AdbcConnection
2929
from sqlspec.core.result import SQLResult
3030
from sqlspec.driver import ExecutionResult
31+
from sqlspec.driver._sync import SyncDataDictionaryBase
3132

3233
__all__ = ("AdbcCursor", "AdbcDriver", "AdbcExceptionHandler", "get_adbc_statement_config")
3334

@@ -393,7 +394,7 @@ class AdbcDriver(SyncDriverAdapterBase):
393394
database dialects, parameter style conversion, and transaction management.
394395
"""
395396

396-
__slots__ = ("_detected_dialect", "dialect")
397+
__slots__ = ("_data_dictionary", "_detected_dialect", "dialect")
397398

398399
def __init__(
399400
self,
@@ -412,6 +413,7 @@ def __init__(
412413

413414
super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features)
414415
self.dialect = statement_config.dialect
416+
self._data_dictionary: Optional[SyncDataDictionaryBase] = None
415417

416418
@staticmethod
417419
def _ensure_pyarrow_installed() -> None:
@@ -654,3 +656,16 @@ def commit(self) -> None:
654656
except Exception as e:
655657
msg = f"Failed to commit transaction: {e}"
656658
raise SQLSpecError(msg) from e
659+
660+
@property
661+
def data_dictionary(self) -> "SyncDataDictionaryBase":
662+
"""Get the data dictionary for this driver.
663+
664+
Returns:
665+
Data dictionary instance for metadata queries
666+
"""
667+
if self._data_dictionary is None:
668+
from sqlspec.adapters.adbc.data_dictionary import AdbcDataDictionary
669+
670+
self._data_dictionary = AdbcDataDictionary()
671+
return self._data_dictionary

0 commit comments

Comments
 (0)