Skip to content

Commit 7961e6c

Browse files
committed
fix: converts filters to be mypyc compatible
1 parent 375da8b commit 7961e6c

File tree

2 files changed

+166
-46
lines changed

2 files changed

+166
-46
lines changed

sqlspec/core/filters.py

Lines changed: 136 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from abc import ABC, abstractmethod
2323
from collections import abc
2424
from collections.abc import Sequence
25-
from dataclasses import dataclass
2625
from datetime import datetime
2726
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union
2827

@@ -74,7 +73,7 @@ def append_to_statement(self, statement: "SQL") -> "SQL":
7473
Parameters should be provided via extract_parameters().
7574
"""
7675

77-
def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
76+
def extract_parameters(self) -> "tuple[list[Any], dict[str, Any]]":
7877
"""Extract parameters that this filter contributes.
7978
8079
Returns:
@@ -118,16 +117,30 @@ def get_cache_key(self) -> tuple[Any, ...]:
118117
"""
119118

120119

121-
@dataclass(frozen=True)
122120
class BeforeAfterFilter(StatementFilter):
123121
"""Filter for datetime range queries.
124122
125123
Applies WHERE clauses for before/after datetime filtering.
126124
"""
127125

128-
field_name: str
129-
before: Optional[datetime] = None
130-
after: Optional[datetime] = None
126+
__slots__ = ("_after", "_before", "_field_name")
127+
128+
def __init__(self, field_name: str, before: Optional[datetime] = None, after: Optional[datetime] = None) -> None:
129+
self._field_name = field_name
130+
self._before = before
131+
self._after = after
132+
133+
@property
134+
def field_name(self) -> str:
135+
return self._field_name
136+
137+
@property
138+
def before(self) -> Optional[datetime]:
139+
return self._before
140+
141+
@property
142+
def after(self) -> Optional[datetime]:
143+
return self._after
131144

132145
def get_param_names(self) -> list[str]:
133146
"""Get parameter names without storing them."""
@@ -184,16 +197,32 @@ def get_cache_key(self) -> tuple[Any, ...]:
184197
return ("BeforeAfterFilter", self.field_name, self.before, self.after)
185198

186199

187-
@dataclass(frozen=True)
188200
class OnBeforeAfterFilter(StatementFilter):
189201
"""Filter for inclusive datetime range queries.
190202
191203
Applies WHERE clauses for on-or-before/on-or-after datetime filtering.
192204
"""
193205

194-
field_name: str
195-
on_or_before: Optional[datetime] = None
196-
on_or_after: Optional[datetime] = None
206+
__slots__ = ("_field_name", "_on_or_after", "_on_or_before")
207+
208+
def __init__(
209+
self, field_name: str, on_or_before: Optional[datetime] = None, on_or_after: Optional[datetime] = None
210+
) -> None:
211+
self._field_name = field_name
212+
self._on_or_before = on_or_before
213+
self._on_or_after = on_or_after
214+
215+
@property
216+
def field_name(self) -> str:
217+
return self._field_name
218+
219+
@property
220+
def on_or_before(self) -> Optional[datetime]:
221+
return self._on_or_before
222+
223+
@property
224+
def on_or_after(self) -> Optional[datetime]:
225+
return self._on_or_after
197226

198227
def get_param_names(self) -> list[str]:
199228
"""Get parameter names without storing them."""
@@ -261,15 +290,25 @@ def append_to_statement(self, statement: "SQL") -> "SQL":
261290
raise NotImplementedError
262291

263292

264-
@dataclass(frozen=True)
265293
class InCollectionFilter(InAnyFilter[T]):
266294
"""Filter for IN clause queries.
267295
268296
Constructs WHERE ... IN (...) clauses.
269297
"""
270298

271-
field_name: str
272-
values: Optional[abc.Collection[T]] = None
299+
__slots__ = ("_field_name", "_values")
300+
301+
def __init__(self, field_name: str, values: Optional[abc.Collection[T]] = None) -> None:
302+
self._field_name = field_name
303+
self._values = values
304+
305+
@property
306+
def field_name(self) -> str:
307+
return self._field_name
308+
309+
@property
310+
def values(self) -> Optional[abc.Collection[T]]:
311+
return self._values
273312

274313
def get_param_names(self) -> list[str]:
275314
"""Get parameter names without storing them."""
@@ -311,15 +350,25 @@ def get_cache_key(self) -> tuple[Any, ...]:
311350
return ("InCollectionFilter", self.field_name, values_tuple)
312351

313352

314-
@dataclass(frozen=True)
315353
class NotInCollectionFilter(InAnyFilter[T]):
316354
"""Filter for NOT IN clause queries.
317355
318356
Constructs WHERE ... NOT IN (...) clauses.
319357
"""
320358

321-
field_name: str
322-
values: Optional[abc.Collection[T]] = None
359+
__slots__ = ("_field_name", "_values")
360+
361+
def __init__(self, field_name: str, values: Optional[abc.Collection[T]] = None) -> None:
362+
self._field_name = field_name
363+
self._values = values
364+
365+
@property
366+
def field_name(self) -> str:
367+
return self._field_name
368+
369+
@property
370+
def values(self) -> Optional[abc.Collection[T]]:
371+
return self._values
323372

324373
def get_param_names(self) -> list[str]:
325374
"""Get parameter names without storing them."""
@@ -361,15 +410,25 @@ def get_cache_key(self) -> tuple[Any, ...]:
361410
return ("NotInCollectionFilter", self.field_name, values_tuple)
362411

363412

364-
@dataclass(frozen=True)
365413
class AnyCollectionFilter(InAnyFilter[T]):
366414
"""Filter for PostgreSQL-style ANY clause queries.
367415
368416
Constructs WHERE column_name = ANY (array_expression) clauses.
369417
"""
370418

371-
field_name: str
372-
values: Optional[abc.Collection[T]] = None
419+
__slots__ = ("_field_name", "_values")
420+
421+
def __init__(self, field_name: str, values: Optional[abc.Collection[T]] = None) -> None:
422+
self._field_name = field_name
423+
self._values = values
424+
425+
@property
426+
def field_name(self) -> str:
427+
return self._field_name
428+
429+
@property
430+
def values(self) -> Optional[abc.Collection[T]]:
431+
return self._values
373432

374433
def get_param_names(self) -> list[str]:
375434
"""Get parameter names without storing them."""
@@ -412,15 +471,25 @@ def get_cache_key(self) -> tuple[Any, ...]:
412471
return ("AnyCollectionFilter", self.field_name, values_tuple)
413472

414473

415-
@dataclass(frozen=True)
416474
class NotAnyCollectionFilter(InAnyFilter[T]):
417475
"""Filter for PostgreSQL-style NOT ANY clause queries.
418476
419477
Constructs WHERE NOT (column_name = ANY (array_expression)) clauses.
420478
"""
421479

422-
field_name: str
423-
values: Optional[abc.Collection[T]] = None
480+
__slots__ = ("_field_name", "_values")
481+
482+
def __init__(self, field_name: str, values: Optional[abc.Collection[T]] = None) -> None:
483+
self._field_name = field_name
484+
self._values = values
485+
486+
@property
487+
def field_name(self) -> str:
488+
return self._field_name
489+
490+
@property
491+
def values(self) -> Optional[abc.Collection[T]]:
492+
return self._values
424493

425494
def get_param_names(self) -> list[str]:
426495
"""Get parameter names without storing them."""
@@ -471,15 +540,25 @@ def append_to_statement(self, statement: "SQL") -> "SQL":
471540
raise NotImplementedError
472541

473542

474-
@dataclass(frozen=True)
475543
class LimitOffsetFilter(PaginationFilter):
476544
"""Filter for LIMIT and OFFSET clauses.
477545
478546
Adds pagination support through LIMIT/OFFSET SQL clauses.
479547
"""
480548

481-
limit: int
482-
offset: int
549+
__slots__ = ("_limit", "_offset")
550+
551+
def __init__(self, limit: int, offset: int) -> None:
552+
self._limit = limit
553+
self._offset = offset
554+
555+
@property
556+
def limit(self) -> int:
557+
return self._limit
558+
559+
@property
560+
def offset(self) -> int:
561+
return self._offset
483562

484563
def get_param_names(self) -> list[str]:
485564
"""Get parameter names without storing them."""
@@ -517,15 +596,25 @@ def get_cache_key(self) -> tuple[Any, ...]:
517596
return ("LimitOffsetFilter", self.limit, self.offset)
518597

519598

520-
@dataclass(frozen=True)
521599
class OrderByFilter(StatementFilter):
522600
"""Filter for ORDER BY clauses.
523601
524602
Adds sorting capability to SQL queries.
525603
"""
526604

527-
field_name: str
528-
sort_order: Literal["asc", "desc"] = "asc"
605+
__slots__ = ("_field_name", "_sort_order")
606+
607+
def __init__(self, field_name: str, sort_order: Literal["asc", "desc"] = "asc") -> None:
608+
self._field_name = field_name
609+
self._sort_order = sort_order
610+
611+
@property
612+
def field_name(self) -> str:
613+
return self._field_name
614+
615+
@property
616+
def sort_order(self) -> Literal["asc", "desc"]:
617+
return self._sort_order # pyright: ignore
529618

530619
def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
531620
"""Extract filter parameters."""
@@ -553,16 +642,30 @@ def get_cache_key(self) -> tuple[Any, ...]:
553642
return ("OrderByFilter", self.field_name, self.sort_order)
554643

555644

556-
@dataclass(frozen=True)
557645
class SearchFilter(StatementFilter):
558646
"""Filter for text search queries.
559647
560648
Constructs WHERE field_name LIKE '%value%' clauses.
561649
"""
562650

563-
field_name: Union[str, set[str]]
564-
value: str
565-
ignore_case: Optional[bool] = False
651+
__slots__ = ("_field_name", "_ignore_case", "_value")
652+
653+
def __init__(self, field_name: Union[str, set[str]], value: str, ignore_case: Optional[bool] = False) -> None:
654+
self._field_name = field_name
655+
self._value = value
656+
self._ignore_case = ignore_case
657+
658+
@property
659+
def field_name(self) -> Union[str, set[str]]:
660+
return self._field_name
661+
662+
@property
663+
def value(self) -> str:
664+
return self._value
665+
666+
@property
667+
def ignore_case(self) -> Optional[bool]:
668+
return self._ignore_case
566669

567670
def get_param_name(self) -> Optional[str]:
568671
"""Get parameter name without storing it."""
@@ -617,7 +720,6 @@ def get_cache_key(self) -> tuple[Any, ...]:
617720
return ("SearchFilter", field_names, self.value, self.ignore_case)
618721

619722

620-
@dataclass(frozen=True)
621723
class NotInSearchFilter(SearchFilter):
622724
"""Filter for negated text search queries.
623725
@@ -732,7 +834,7 @@ def apply_filter(statement: "SQL", filter_obj: StatementFilter) -> "SQL":
732834
def create_filters(filters: "list[StatementFilter]") -> tuple["StatementFilter", ...]:
733835
"""Convert mutable filters to immutable tuple.
734836
735-
Since StatementFilter classes are now immutable (frozen dataclasses),
837+
Since StatementFilter classes are now immutable (with read-only properties),
736838
we just need to convert to a tuple for consistent sharing.
737839
738840
Args:

tests/unit/test_utils/test_type_guards.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@
55
"""
66

77
from dataclasses import dataclass
8-
from typing import TYPE_CHECKING, Any, Optional, cast
8+
from typing import Any, Optional, cast
99

1010
import msgspec
1111
import pytest
12-
13-
if TYPE_CHECKING:
14-
from sqlglot import exp
12+
from sqlglot import exp
1513

1614
from sqlspec.utils.type_guards import (
1715
dataclass_to_dict,
@@ -64,6 +62,8 @@
6462

6563
pytestmark = pytest.mark.xdist_group("utils")
6664

65+
_UNSET = object()
66+
6767

6868
@dataclass
6969
class SampleDataclass:
@@ -74,18 +74,36 @@ class SampleDataclass:
7474
optional_field: "Optional[str]" = None
7575

7676

77-
class MockSQLGlotExpression(exp.Expression):
78-
"""Mock SQLGlot expression for testing."""
77+
class MockSQLGlotExpression:
78+
"""Mock SQLGlot expression for testing type guard functions.
79+
80+
This mock allows us to test cases where attributes don't exist,
81+
which is needed to test the AttributeError handling in type guards.
82+
"""
7983

8084
def __init__(
8185
self,
82-
this: "Optional[Any]" = None,
83-
expressions: "Optional[list[Any]]" = None,
84-
parent: "Optional[Any]" = None,
86+
this: Any = _UNSET,
87+
expressions: Any = _UNSET,
88+
parent: Any = _UNSET,
8589
args: "Optional[dict[str, Any]]" = None,
8690
) -> None:
87-
# Call parent constructor with proper args
88-
super().__init__(this=this, expressions=expressions or [], parent=parent, **args or {})
91+
# Only set attributes if they were explicitly provided
92+
if this is not _UNSET:
93+
self.this = this
94+
if expressions is not _UNSET:
95+
self.expressions = expressions
96+
if parent is not _UNSET:
97+
self.parent = parent
98+
99+
# SQLGlot expressions always have an args dict
100+
self.args = args or {}
101+
102+
# Set any additional attributes from args
103+
if args:
104+
for key, value in args.items():
105+
if key not in {"this", "expressions", "parent"}:
106+
setattr(self, key, value)
89107

90108

91109
class MockLiteral:
@@ -606,7 +624,7 @@ def __init__(self) -> None:
606624
self.initial_expression = mock_expr
607625

608626
context = MockContext()
609-
assert get_initial_expression(context) is mock_expr
627+
assert get_initial_expression(context) is mock_expr # type: ignore[comparison-overlap]
610628

611629

612630
def test_get_initial_expression_without_attribute() -> None:

0 commit comments

Comments
 (0)