Skip to content

Commit 3656552

Browse files
authored
fix: StatementFilter and parameter validation fix (#34)
This corrects an issue where the parameters were being incorrectly validated when also using Statement Filters.
1 parent f820d71 commit 3656552

File tree

16 files changed

+496
-815
lines changed

16 files changed

+496
-815
lines changed

docs/examples/litestar_asyncpg.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,26 @@
1313
# ]
1414
# ///
1515

16+
from typing import Annotated, Optional
17+
1618
from litestar import Litestar, get
19+
from litestar.params import Dependency
1720

1821
from sqlspec.adapters.asyncpg import AsyncpgConfig, AsyncpgDriver, AsyncpgPoolConfig
19-
from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec
22+
from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec, providers
23+
from sqlspec.filters import FilterTypes
2024

2125

22-
@get("/")
23-
async def simple_asyncpg(db_session: AsyncpgDriver) -> dict[str, str]:
24-
return await db_session.select_one("SELECT 'Hello, world!' AS greeting")
26+
@get(
27+
"/",
28+
dependencies=providers.create_filter_dependencies({"search": "greeting", "search_ignore_case": True}),
29+
)
30+
async def simple_asyncpg(
31+
db_session: AsyncpgDriver, filters: Annotated[list[FilterTypes], Dependency(skip_validation=True)]
32+
) -> Optional[dict[str, str]]:
33+
return await db_session.select_one_or_none(
34+
"SELECT greeting FROM (select 'Hello, world!' as greeting) as t", *filters
35+
)
2536

2637

2738
sqlspec = SQLSpec(

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ nanoid = ["fastnanoid>=0.4.1"]
3030
oracledb = ["oracledb"]
3131
orjson = ["orjson"]
3232
performance = ["sqlglot[rs]", "msgspec"]
33+
polars = ["polars", "pyarrow"]
3334
psqlpy = ["psqlpy"]
3435
psycopg = ["psycopg[binary,pool]"]
3536
pydantic = ["pydantic", "pydantic-extra-types"]

sqlspec/adapters/adbc/driver.py

Lines changed: 37 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import contextlib
22
import logging
33
import re
4-
from collections.abc import Generator, Sequence
4+
from collections.abc import Generator, Mapping, Sequence
55
from contextlib import contextmanager
66
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union, cast, overload
77

88
from adbc_driver_manager.dbapi import Connection, Cursor
99
from sqlglot import exp as sqlglot_exp
1010

1111
from sqlspec.base import SyncDriverAdapterProtocol
12-
from sqlspec.exceptions import ParameterStyleMismatchError, SQLParsingError
12+
from sqlspec.exceptions import SQLParsingError
1313
from sqlspec.filters import StatementFilter
1414
from sqlspec.mixins import ResultConverter, SQLTranslatorMixin, SyncArrowBulkOperationsMixin
1515
from sqlspec.statement import SQLStatement
@@ -91,7 +91,6 @@ def _process_sql_params( # noqa: C901, PLR0912, PLR0915
9191
self,
9292
sql: str,
9393
parameters: "Optional[StatementParameterType]" = None,
94-
/,
9594
*filters: "StatementFilter",
9695
**kwargs: Any,
9796
) -> "tuple[str, Optional[tuple[Any, ...]]]": # Always returns tuple or None for params
@@ -108,14 +107,24 @@ def _process_sql_params( # noqa: C901, PLR0912, PLR0915
108107
**kwargs: Additional keyword arguments.
109108
110109
Raises:
111-
ParameterStyleMismatchError: If positional parameters are mixed with keyword arguments.
112110
SQLParsingError: If the SQL statement cannot be parsed.
113111
114112
Returns:
115113
A tuple of (sql, parameters) ready for execution.
116114
"""
115+
passed_parameters: Optional[Union[Mapping[str, Any], Sequence[Any]]] = None
116+
combined_filters_list: list[StatementFilter] = list(filters)
117+
118+
if parameters is not None:
119+
if isinstance(parameters, StatementFilter):
120+
combined_filters_list.insert(0, parameters)
121+
# passed_parameters remains None
122+
else:
123+
# If parameters is not a StatementFilter, it's actual data parameters.
124+
passed_parameters = parameters
125+
117126
# Special handling for SQLite with non-dict parameters and named placeholders
118-
if self.dialect == "sqlite" and parameters is not None and not is_dict(parameters):
127+
if self.dialect == "sqlite" and passed_parameters is not None and not is_dict(passed_parameters):
119128
# First mask out comments and strings to avoid detecting parameters in those
120129
comments = list(SQL_COMMENT_PATTERN.finditer(sql))
121130
strings = list(SQL_STRING_PATTERN.finditer(sql))
@@ -136,26 +145,15 @@ def _process_sql_params( # noqa: C901, PLR0912, PLR0915
136145
param_positions.sort(reverse=True)
137146
for start, end in param_positions:
138147
sql = sql[:start] + "?" + sql[end:]
139-
if not isinstance(parameters, (list, tuple)):
140-
return sql, (parameters,)
141-
return sql, tuple(parameters)
148+
if not isinstance(passed_parameters, (list, tuple)):
149+
passed_parameters = (passed_parameters,)
150+
passed_parameters = tuple(passed_parameters)
142151

143152
# Standard processing for all other cases
144-
merged_params = parameters
145-
if kwargs:
146-
if is_dict(parameters):
147-
merged_params = {**parameters, **kwargs}
148-
elif parameters is not None:
149-
msg = "Cannot mix positional parameters with keyword arguments for adbc driver."
150-
raise ParameterStyleMismatchError(msg)
151-
else:
152-
merged_params = kwargs
153+
statement = SQLStatement(sql, passed_parameters, kwargs=kwargs, dialect=self.dialect)
153154

154-
# 2. Create SQLStatement with dialect and process
155-
statement = SQLStatement(sql, merged_params, dialect=self.dialect)
156-
157-
# Apply any filters
158-
for filter_obj in filters:
155+
# Apply any filters from combined_filters_list
156+
for filter_obj in combined_filters_list:
159157
statement = statement.apply_filter(filter_obj)
160158

161159
processed_sql, processed_params, parsed_expr = statement.process()
@@ -284,7 +282,6 @@ def select(
284282
self,
285283
sql: str,
286284
parameters: "Optional[StatementParameterType]" = None,
287-
/,
288285
*filters: "StatementFilter",
289286
connection: "Optional[AdbcConnection]" = None,
290287
schema_type: None = None,
@@ -295,7 +292,6 @@ def select(
295292
self,
296293
sql: str,
297294
parameters: "Optional[StatementParameterType]" = None,
298-
/,
299295
*filters: "StatementFilter",
300296
connection: "Optional[AdbcConnection]" = None,
301297
schema_type: "type[ModelDTOT]",
@@ -305,7 +301,6 @@ def select(
305301
self,
306302
sql: str,
307303
parameters: Optional["StatementParameterType"] = None,
308-
/,
309304
*filters: "StatementFilter",
310305
connection: Optional["AdbcConnection"] = None,
311306
schema_type: "Optional[type[ModelDTOT]]" = None,
@@ -341,7 +336,6 @@ def select_one(
341336
self,
342337
sql: str,
343338
parameters: "Optional[StatementParameterType]" = None,
344-
/,
345339
*filters: "StatementFilter",
346340
connection: "Optional[AdbcConnection]" = None,
347341
schema_type: None = None,
@@ -352,7 +346,6 @@ def select_one(
352346
self,
353347
sql: str,
354348
parameters: "Optional[StatementParameterType]" = None,
355-
/,
356349
*filters: "StatementFilter",
357350
connection: "Optional[AdbcConnection]" = None,
358351
schema_type: "type[ModelDTOT]",
@@ -362,7 +355,6 @@ def select_one(
362355
self,
363356
sql: str,
364357
parameters: "Optional[StatementParameterType]" = None,
365-
/,
366358
*filters: "StatementFilter",
367359
connection: "Optional[AdbcConnection]" = None,
368360
schema_type: "Optional[type[ModelDTOT]]" = None,
@@ -396,7 +388,6 @@ def select_one_or_none(
396388
self,
397389
sql: str,
398390
parameters: "Optional[StatementParameterType]" = None,
399-
/,
400391
*filters: "StatementFilter",
401392
connection: "Optional[AdbcConnection]" = None,
402393
schema_type: None = None,
@@ -407,7 +398,6 @@ def select_one_or_none(
407398
self,
408399
sql: str,
409400
parameters: "Optional[StatementParameterType]" = None,
410-
/,
411401
*filters: "StatementFilter",
412402
connection: "Optional[AdbcConnection]" = None,
413403
schema_type: "type[ModelDTOT]",
@@ -417,7 +407,6 @@ def select_one_or_none(
417407
self,
418408
sql: str,
419409
parameters: Optional["StatementParameterType"] = None,
420-
/,
421410
*filters: "StatementFilter",
422411
connection: Optional["AdbcConnection"] = None,
423412
schema_type: "Optional[type[ModelDTOT]]" = None,
@@ -452,8 +441,7 @@ def select_value(
452441
self,
453442
sql: str,
454443
parameters: "Optional[StatementParameterType]" = None,
455-
/,
456-
*filters: StatementFilter,
444+
*filters: "StatementFilter",
457445
connection: "Optional[AdbcConnection]" = None,
458446
schema_type: None = None,
459447
**kwargs: Any,
@@ -463,8 +451,7 @@ def select_value(
463451
self,
464452
sql: str,
465453
parameters: "Optional[StatementParameterType]" = None,
466-
/,
467-
*filters: StatementFilter,
454+
*filters: "StatementFilter",
468455
connection: "Optional[AdbcConnection]" = None,
469456
schema_type: "type[T]",
470457
**kwargs: Any,
@@ -473,8 +460,7 @@ def select_value(
473460
self,
474461
sql: str,
475462
parameters: "Optional[StatementParameterType]" = None,
476-
/,
477-
*filters: StatementFilter,
463+
*filters: "StatementFilter",
478464
connection: "Optional[AdbcConnection]" = None,
479465
schema_type: "Optional[type[T]]" = None,
480466
**kwargs: Any,
@@ -508,8 +494,7 @@ def select_value_or_none(
508494
self,
509495
sql: str,
510496
parameters: "Optional[StatementParameterType]" = None,
511-
/,
512-
*filters: StatementFilter,
497+
*filters: "StatementFilter",
513498
connection: "Optional[AdbcConnection]" = None,
514499
schema_type: None = None,
515500
**kwargs: Any,
@@ -519,8 +504,7 @@ def select_value_or_none(
519504
self,
520505
sql: str,
521506
parameters: "Optional[StatementParameterType]" = None,
522-
/,
523-
*filters: StatementFilter,
507+
*filters: "StatementFilter",
524508
connection: "Optional[AdbcConnection]" = None,
525509
schema_type: "type[T]",
526510
**kwargs: Any,
@@ -529,8 +513,7 @@ def select_value_or_none(
529513
self,
530514
sql: str,
531515
parameters: "Optional[StatementParameterType]" = None,
532-
/,
533-
*filters: StatementFilter,
516+
*filters: "StatementFilter",
534517
connection: "Optional[AdbcConnection]" = None,
535518
schema_type: "Optional[type[T]]" = None,
536519
**kwargs: Any,
@@ -564,22 +547,21 @@ def insert_update_delete(
564547
self,
565548
sql: str,
566549
parameters: "Optional[StatementParameterType]" = None,
567-
/,
568550
*filters: "StatementFilter",
569551
connection: "Optional[AdbcConnection]" = None,
570552
**kwargs: Any,
571553
) -> int:
572554
"""Execute an insert, update, or delete statement.
573555
574556
Args:
575-
sql: The SQL statement to execute.
557+
sql: The SQL statement string.
576558
parameters: The parameters for the statement (dict, tuple, list, or None).
577559
*filters: Statement filters to apply.
578560
connection: Optional connection override.
579561
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
580562
581563
Returns:
582-
The number of rows affected by the statement.
564+
Row count affected by the operation.
583565
"""
584566
connection = self._connection(connection)
585567
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
@@ -593,8 +575,7 @@ def insert_update_delete_returning(
593575
self,
594576
sql: str,
595577
parameters: "Optional[StatementParameterType]" = None,
596-
/,
597-
*filters: StatementFilter,
578+
*filters: "StatementFilter",
598579
connection: "Optional[AdbcConnection]" = None,
599580
schema_type: None = None,
600581
**kwargs: Any,
@@ -604,8 +585,7 @@ def insert_update_delete_returning(
604585
self,
605586
sql: str,
606587
parameters: "Optional[StatementParameterType]" = None,
607-
/,
608-
*filters: StatementFilter,
588+
*filters: "StatementFilter",
609589
connection: "Optional[AdbcConnection]" = None,
610590
schema_type: "type[ModelDTOT]",
611591
**kwargs: Any,
@@ -614,24 +594,23 @@ def insert_update_delete_returning(
614594
self,
615595
sql: str,
616596
parameters: "Optional[StatementParameterType]" = None,
617-
/,
618-
*filters: StatementFilter,
597+
*filters: "StatementFilter",
619598
connection: "Optional[AdbcConnection]" = None,
620599
schema_type: "Optional[type[ModelDTOT]]" = None,
621600
**kwargs: Any,
622601
) -> "Optional[Union[dict[str, Any], ModelDTOT]]":
623-
"""Insert, update, or delete data from the database and return result.
602+
"""Insert, update, or delete data with RETURNING clause.
624603
625604
Args:
626-
sql: The SQL statement to execute.
605+
sql: The SQL statement string.
627606
parameters: The parameters for the statement (dict, tuple, list, or None).
628607
*filters: Statement filters to apply.
629608
connection: Optional connection override.
630609
schema_type: Optional schema class for the result.
631610
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
632611
633612
Returns:
634-
The first row of results.
613+
The returned row data, or None if no row returned.
635614
"""
636615
connection = self._connection(connection)
637616
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)
@@ -648,7 +627,6 @@ def execute_script(
648627
self,
649628
sql: str,
650629
parameters: "Optional[StatementParameterType]" = None,
651-
/,
652630
connection: "Optional[AdbcConnection]" = None,
653631
**kwargs: Any,
654632
) -> str:
@@ -673,12 +651,11 @@ def execute_script(
673651

674652
# --- Arrow Bulk Operations ---
675653

676-
def select_arrow( # pyright: ignore[reportUnknownParameterType]
654+
def select_arrow(
677655
self,
678656
sql: str,
679657
parameters: "Optional[StatementParameterType]" = None,
680-
/,
681-
*filters: StatementFilter,
658+
*filters: "StatementFilter",
682659
connection: "Optional[AdbcConnection]" = None,
683660
**kwargs: Any,
684661
) -> "ArrowTable": # pyright: ignore[reportUnknownVariableType]
@@ -692,7 +669,7 @@ def select_arrow( # pyright: ignore[reportUnknownParameterType]
692669
**kwargs: Additional keyword arguments to merge with parameters if parameters is a dict.
693670
694671
Returns:
695-
An Apache Arrow Table containing the query results.
672+
An Arrow Table containing the query results.
696673
"""
697674
connection = self._connection(connection)
698675
sql, parameters = self._process_sql_params(sql, parameters, *filters, **kwargs)

0 commit comments

Comments
 (0)