7
7
from sqlspec .utils .logging import get_logger
8
8
9
9
if TYPE_CHECKING :
10
+ from collections .abc import Callable
11
+
10
12
from sqlspec .adapters .adbc .driver import AdbcDriver
11
13
12
14
logger = get_logger ("adapters.adbc.data_dictionary" )
22
24
class AdbcDataDictionary (SyncDataDictionaryBase ):
23
25
"""ADBC multi-dialect data dictionary.
24
26
25
- Detects the underlying database dialect and delegates to appropriate logic .
27
+ Delegates to appropriate dialect-specific logic based on the driver's dialect .
26
28
"""
27
29
28
- def __init__ (self ) -> None :
29
- """Initialize ADBC data dictionary."""
30
- self ._detected_dialect : Optional [str ] = None
31
- self ._cached_version : Optional [VersionInfo ] = None
32
-
33
- def _detect_dialect (self , driver : SyncDriverAdapterBase ) -> str :
34
- """Detect the underlying database dialect.
30
+ def _get_dialect (self , driver : SyncDriverAdapterBase ) -> str :
31
+ """Get dialect from ADBC driver.
35
32
36
33
Args:
37
34
driver: ADBC driver instance
38
35
39
36
Returns:
40
- Detected dialect name
37
+ Dialect name
41
38
"""
42
- if self ._detected_dialect :
43
- return self ._detected_dialect
44
-
45
- self ._detected_dialect = (
46
- str (cast ("AdbcDriver" , driver ).dialect ) if cast ("AdbcDriver" , driver ).dialect else "sqlite"
47
- )
48
- return self ._detected_dialect
39
+ return str (cast ("AdbcDriver" , driver ).dialect )
49
40
50
41
def get_version (self , driver : SyncDriverAdapterBase ) -> "Optional[VersionInfo]" :
51
42
"""Get database version information based on detected dialect.
@@ -56,59 +47,51 @@ def get_version(self, driver: SyncDriverAdapterBase) -> "Optional[VersionInfo]":
56
47
Returns:
57
48
Database version information or None if detection fails
58
49
"""
59
- if self ._cached_version :
60
- return self ._cached_version
61
-
62
- dialect = self ._detect_dialect (driver )
63
-
64
- if dialect == "postgres" :
65
- try :
66
- version_str = cast ("AdbcDriver" , driver ).select_value ("SELECT version()" )
67
- version_match = POSTGRES_VERSION_PATTERN .search (str (version_str ))
68
- if version_match :
69
- major = int (version_match .group (1 ))
70
- minor = int (version_match .group (2 ))
71
- patch = int (version_match .group (3 )) if version_match .group (3 ) else 0
72
- self ._cached_version = VersionInfo (major , minor , patch )
73
- except Exception :
74
- logger .warning ("Failed to get PostgreSQL version" )
75
-
76
- elif dialect == "sqlite" :
77
- try :
78
- version_str = cast ("AdbcDriver" , driver ).select_value ("SELECT sqlite_version()" )
79
- version_match = SQLITE_VERSION_PATTERN .match (str (version_str ))
80
- if version_match :
81
- major , minor , patch = map (int , version_match .groups ())
82
- self ._cached_version = VersionInfo (major , minor , patch )
83
- except Exception :
84
- logger .warning ("Failed to get SQLite version" )
85
-
86
- elif dialect == "duckdb" :
87
- try :
88
- version_str = cast ("AdbcDriver" , driver ).select_value ("SELECT version()" )
89
- version_match = DUCKDB_VERSION_PATTERN .search (str (version_str ))
90
- if version_match :
91
- major , minor , patch = map (int , version_match .groups ())
92
- self ._cached_version = VersionInfo (major , minor , patch )
93
- except Exception :
94
- logger .warning ("Failed to get DuckDB version" )
95
-
96
- elif dialect == "mysql" :
97
- try :
98
- version_str = cast ("AdbcDriver" , driver ).select_value ("SELECT VERSION()" )
99
- version_match = MYSQL_VERSION_PATTERN .search (str (version_str ))
100
- if version_match :
101
- major , minor , patch = map (int , version_match .groups ())
102
- self ._cached_version = VersionInfo (major , minor , patch )
103
- except Exception :
104
- logger .warning ("Failed to get MySQL version" )
105
-
106
- elif dialect == "bigquery" :
107
- # BigQuery is a cloud service
108
- self ._cached_version = VersionInfo (1 , 0 , 0 )
109
-
110
- logger .debug ("Detected %s version: %s" , dialect , self ._cached_version )
111
- return self ._cached_version
50
+ dialect = self ._get_dialect (driver )
51
+ adbc_driver = cast ("AdbcDriver" , driver )
52
+
53
+ try :
54
+ if dialect == "postgres" :
55
+ version_str = adbc_driver .select_value ("SELECT version()" )
56
+ if version_str :
57
+ match = POSTGRES_VERSION_PATTERN .search (str (version_str ))
58
+ if match :
59
+ major = int (match .group (1 ))
60
+ minor = int (match .group (2 ))
61
+ patch = int (match .group (3 )) if match .group (3 ) else 0
62
+ return VersionInfo (major , minor , patch )
63
+
64
+ elif dialect == "sqlite" :
65
+ version_str = adbc_driver .select_value ("SELECT sqlite_version()" )
66
+ if version_str :
67
+ match = SQLITE_VERSION_PATTERN .match (str (version_str ))
68
+ if match :
69
+ major , minor , patch = map (int , match .groups ())
70
+ return VersionInfo (major , minor , patch )
71
+
72
+ elif dialect == "duckdb" :
73
+ version_str = adbc_driver .select_value ("SELECT version()" )
74
+ if version_str :
75
+ match = DUCKDB_VERSION_PATTERN .search (str (version_str ))
76
+ if match :
77
+ major , minor , patch = map (int , match .groups ())
78
+ return VersionInfo (major , minor , patch )
79
+
80
+ elif dialect == "mysql" :
81
+ version_str = adbc_driver .select_value ("SELECT VERSION()" )
82
+ if version_str :
83
+ match = MYSQL_VERSION_PATTERN .search (str (version_str ))
84
+ if match :
85
+ major , minor , patch = map (int , match .groups ())
86
+ return VersionInfo (major , minor , patch )
87
+
88
+ elif dialect == "bigquery" :
89
+ return VersionInfo (1 , 0 , 0 )
90
+
91
+ except Exception :
92
+ logger .warning ("Failed to get %s version" , dialect )
93
+
94
+ return None
112
95
113
96
def get_feature_flag (self , driver : SyncDriverAdapterBase , feature : str ) -> bool :
114
97
"""Check if database supports a specific feature based on detected dialect.
@@ -120,25 +103,35 @@ def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
120
103
Returns:
121
104
True if feature is supported, False otherwise
122
105
"""
123
- dialect = self ._detect_dialect (driver )
106
+ dialect = self ._get_dialect (driver )
124
107
version_info = self .get_version (driver )
125
108
126
109
if dialect == "postgres" :
127
- feature_checks = {
110
+ feature_checks : dict [ str , Callable [..., bool ]] = {
128
111
"supports_json" : lambda v : v and v >= VersionInfo (9 , 2 , 0 ),
129
112
"supports_jsonb" : lambda v : v and v >= VersionInfo (9 , 4 , 0 ),
130
113
"supports_uuid" : lambda _ : True ,
131
114
"supports_arrays" : lambda _ : True ,
132
115
"supports_returning" : lambda v : v and v >= VersionInfo (8 , 2 , 0 ),
133
116
"supports_upsert" : lambda v : v and v >= VersionInfo (9 , 5 , 0 ),
117
+ "supports_window_functions" : lambda v : v and v >= VersionInfo (8 , 4 , 0 ),
118
+ "supports_cte" : lambda v : v and v >= VersionInfo (8 , 4 , 0 ),
119
+ "supports_transactions" : lambda _ : True ,
120
+ "supports_prepared_statements" : lambda _ : True ,
121
+ "supports_schemas" : lambda _ : True ,
134
122
}
135
123
elif dialect == "sqlite" :
136
124
feature_checks = {
137
125
"supports_json" : lambda v : v and v >= VersionInfo (3 , 38 , 0 ),
138
126
"supports_returning" : lambda v : v and v >= VersionInfo (3 , 35 , 0 ),
139
127
"supports_upsert" : lambda v : v and v >= VersionInfo (3 , 24 , 0 ),
140
- "supports_uuid" : lambda _ : False ,
128
+ "supports_window_functions" : lambda v : v and v >= VersionInfo (3 , 25 , 0 ),
129
+ "supports_cte" : lambda v : v and v >= VersionInfo (3 , 8 , 3 ),
130
+ "supports_transactions" : lambda _ : True ,
131
+ "supports_prepared_statements" : lambda _ : True ,
132
+ "supports_schemas" : lambda _ : False ,
141
133
"supports_arrays" : lambda _ : False ,
134
+ "supports_uuid" : lambda _ : False ,
142
135
}
143
136
elif dialect == "duckdb" :
144
137
feature_checks = {
@@ -147,13 +140,22 @@ def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
147
140
"supports_uuid" : lambda _ : True ,
148
141
"supports_returning" : lambda v : v and v >= VersionInfo (0 , 8 , 0 ),
149
142
"supports_upsert" : lambda v : v and v >= VersionInfo (0 , 8 , 0 ),
143
+ "supports_window_functions" : lambda _ : True ,
144
+ "supports_cte" : lambda _ : True ,
145
+ "supports_transactions" : lambda _ : True ,
146
+ "supports_prepared_statements" : lambda _ : True ,
147
+ "supports_schemas" : lambda _ : True ,
150
148
}
151
149
elif dialect == "mysql" :
152
150
feature_checks = {
153
151
"supports_json" : lambda v : v and v >= VersionInfo (5 , 7 , 8 ),
154
152
"supports_cte" : lambda v : v and v >= VersionInfo (8 , 0 , 1 ),
155
153
"supports_returning" : lambda _ : False ,
156
154
"supports_upsert" : lambda _ : True ,
155
+ "supports_window_functions" : lambda v : v and v >= VersionInfo (8 , 0 , 2 ),
156
+ "supports_transactions" : lambda _ : True ,
157
+ "supports_prepared_statements" : lambda _ : True ,
158
+ "supports_schemas" : lambda _ : True ,
157
159
"supports_uuid" : lambda _ : False ,
158
160
"supports_arrays" : lambda _ : False ,
159
161
}
@@ -164,23 +166,23 @@ def get_feature_flag(self, driver: SyncDriverAdapterBase, feature: str) -> bool:
164
166
"supports_structs" : lambda _ : True ,
165
167
"supports_returning" : lambda _ : False ,
166
168
"supports_upsert" : lambda _ : True ,
169
+ "supports_window_functions" : lambda _ : True ,
170
+ "supports_cte" : lambda _ : True ,
171
+ "supports_transactions" : lambda _ : False ,
172
+ "supports_prepared_statements" : lambda _ : True ,
173
+ "supports_schemas" : lambda _ : True ,
167
174
"supports_uuid" : lambda _ : False ,
168
175
}
169
176
else :
170
- feature_checks = {}
171
-
172
- # Common features
173
- common_features = {
174
- "supports_transactions" : lambda _ : True ,
175
- "supports_prepared_statements" : lambda _ : True ,
176
- "supports_window_functions" : lambda _ : True ,
177
- "supports_cte" : lambda _ : True ,
178
- }
179
-
180
- feature_checks .update (common_features )
177
+ feature_checks = {
178
+ "supports_transactions" : lambda _ : True ,
179
+ "supports_prepared_statements" : lambda _ : True ,
180
+ "supports_window_functions" : lambda _ : True ,
181
+ "supports_cte" : lambda _ : True ,
182
+ }
181
183
182
184
if feature in feature_checks :
183
- return feature_checks [feature ](version_info )
185
+ return bool ( feature_checks [feature ](version_info ) )
184
186
185
187
return False
186
188
@@ -194,7 +196,7 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) ->
194
196
Returns:
195
197
Database-specific type name
196
198
"""
197
- dialect = self ._detect_dialect (driver )
199
+ dialect = self ._get_dialect (driver )
198
200
version_info = self .get_version (driver )
199
201
200
202
if dialect == "postgres" :
@@ -210,6 +212,7 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) ->
210
212
"timestamp" : "TIMESTAMP WITH TIME ZONE" ,
211
213
"text" : "TEXT" ,
212
214
"blob" : "BYTEA" ,
215
+ "array" : "ARRAY" ,
213
216
}
214
217
215
218
elif dialect == "sqlite" :
@@ -254,7 +257,6 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) ->
254
257
"array" : "ARRAY" ,
255
258
}
256
259
else :
257
- # Generic fallback
258
260
type_map = {
259
261
"json" : "TEXT" ,
260
262
"uuid" : "VARCHAR(36)" ,
0 commit comments