Skip to content

Commit 2c29893

Browse files
authored
fix: correctly handle named paramter values (#46)
Enhance the parameter naming strategy in the builder to ensure unique names based on column names or positional indices, improving clarity and preventing conflicts.
1 parent 4df58ec commit 2c29893

File tree

16 files changed

+1851
-132
lines changed

16 files changed

+1851
-132
lines changed

sqlspec/_sql.py

Lines changed: 252 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,65 @@
44
"""
55

66
import logging
7-
from typing import Any, Optional, Union
7+
from typing import TYPE_CHECKING, Any, Optional, Union
88

99
import sqlglot
1010
from sqlglot import exp
1111
from sqlglot.dialects.dialect import DialectType
1212
from sqlglot.errors import ParseError as SQLGlotParseError
1313

14-
from sqlspec.builder import Column, Delete, Insert, Merge, Select, Truncate, Update
14+
from sqlspec.builder import (
15+
AlterTable,
16+
Column,
17+
CommentOn,
18+
CreateIndex,
19+
CreateMaterializedView,
20+
CreateSchema,
21+
CreateTable,
22+
CreateTableAsSelect,
23+
CreateView,
24+
Delete,
25+
DropIndex,
26+
DropSchema,
27+
DropTable,
28+
DropView,
29+
Insert,
30+
Merge,
31+
RenameTable,
32+
Select,
33+
Truncate,
34+
Update,
35+
)
1536
from sqlspec.exceptions import SQLBuilderError
1637

17-
__all__ = ("Case", "Column", "Delete", "Insert", "Merge", "SQLFactory", "Select", "Truncate", "Update", "sql")
38+
if TYPE_CHECKING:
39+
from sqlspec.core.statement import SQL
40+
41+
__all__ = (
42+
"AlterTable",
43+
"Case",
44+
"Column",
45+
"CommentOn",
46+
"CreateIndex",
47+
"CreateMaterializedView",
48+
"CreateSchema",
49+
"CreateTable",
50+
"CreateTableAsSelect",
51+
"CreateView",
52+
"Delete",
53+
"DropIndex",
54+
"DropSchema",
55+
"DropTable",
56+
"DropView",
57+
"Insert",
58+
"Merge",
59+
"RenameTable",
60+
"SQLFactory",
61+
"Select",
62+
"Truncate",
63+
"Update",
64+
"sql",
65+
)
1866

1967
logger = logging.getLogger("sqlspec")
2068

@@ -212,6 +260,174 @@ def merge(self, table_or_sql: Optional[str] = None, dialect: DialectType = None)
212260
return builder.into(table_or_sql)
213261
return builder
214262

263+
# ===================
264+
# DDL Statement Builders
265+
# ===================
266+
267+
def create_table(self, table_name: str, dialect: DialectType = None) -> "CreateTable":
268+
"""Create a CREATE TABLE builder.
269+
270+
Args:
271+
table_name: Name of the table to create
272+
dialect: Optional SQL dialect
273+
274+
Returns:
275+
CreateTable builder instance
276+
"""
277+
builder = CreateTable(table_name)
278+
builder.dialect = dialect or self.dialect
279+
return builder
280+
281+
def create_table_as_select(self, dialect: DialectType = None) -> "CreateTableAsSelect":
282+
"""Create a CREATE TABLE AS SELECT builder.
283+
284+
Args:
285+
dialect: Optional SQL dialect
286+
287+
Returns:
288+
CreateTableAsSelect builder instance
289+
"""
290+
builder = CreateTableAsSelect()
291+
builder.dialect = dialect or self.dialect
292+
return builder
293+
294+
def create_view(self, dialect: DialectType = None) -> "CreateView":
295+
"""Create a CREATE VIEW builder.
296+
297+
Args:
298+
dialect: Optional SQL dialect
299+
300+
Returns:
301+
CreateView builder instance
302+
"""
303+
builder = CreateView()
304+
builder.dialect = dialect or self.dialect
305+
return builder
306+
307+
def create_materialized_view(self, dialect: DialectType = None) -> "CreateMaterializedView":
308+
"""Create a CREATE MATERIALIZED VIEW builder.
309+
310+
Args:
311+
dialect: Optional SQL dialect
312+
313+
Returns:
314+
CreateMaterializedView builder instance
315+
"""
316+
builder = CreateMaterializedView()
317+
builder.dialect = dialect or self.dialect
318+
return builder
319+
320+
def create_index(self, index_name: str, dialect: DialectType = None) -> "CreateIndex":
321+
"""Create a CREATE INDEX builder.
322+
323+
Args:
324+
index_name: Name of the index to create
325+
dialect: Optional SQL dialect
326+
327+
Returns:
328+
CreateIndex builder instance
329+
"""
330+
return CreateIndex(index_name, dialect=dialect or self.dialect)
331+
332+
def create_schema(self, dialect: DialectType = None) -> "CreateSchema":
333+
"""Create a CREATE SCHEMA builder.
334+
335+
Args:
336+
dialect: Optional SQL dialect
337+
338+
Returns:
339+
CreateSchema builder instance
340+
"""
341+
builder = CreateSchema()
342+
builder.dialect = dialect or self.dialect
343+
return builder
344+
345+
def drop_table(self, table_name: str, dialect: DialectType = None) -> "DropTable":
346+
"""Create a DROP TABLE builder.
347+
348+
Args:
349+
table_name: Name of the table to drop
350+
dialect: Optional SQL dialect
351+
352+
Returns:
353+
DropTable builder instance
354+
"""
355+
return DropTable(table_name, dialect=dialect or self.dialect)
356+
357+
def drop_view(self, dialect: DialectType = None) -> "DropView":
358+
"""Create a DROP VIEW builder.
359+
360+
Args:
361+
dialect: Optional SQL dialect
362+
363+
Returns:
364+
DropView builder instance
365+
"""
366+
return DropView(dialect=dialect or self.dialect)
367+
368+
def drop_index(self, index_name: str, dialect: DialectType = None) -> "DropIndex":
369+
"""Create a DROP INDEX builder.
370+
371+
Args:
372+
index_name: Name of the index to drop
373+
dialect: Optional SQL dialect
374+
375+
Returns:
376+
DropIndex builder instance
377+
"""
378+
return DropIndex(index_name, dialect=dialect or self.dialect)
379+
380+
def drop_schema(self, dialect: DialectType = None) -> "DropSchema":
381+
"""Create a DROP SCHEMA builder.
382+
383+
Args:
384+
dialect: Optional SQL dialect
385+
386+
Returns:
387+
DropSchema builder instance
388+
"""
389+
return DropSchema(dialect=dialect or self.dialect)
390+
391+
def alter_table(self, table_name: str, dialect: DialectType = None) -> "AlterTable":
392+
"""Create an ALTER TABLE builder.
393+
394+
Args:
395+
table_name: Name of the table to alter
396+
dialect: Optional SQL dialect
397+
398+
Returns:
399+
AlterTable builder instance
400+
"""
401+
builder = AlterTable(table_name)
402+
builder.dialect = dialect or self.dialect
403+
return builder
404+
405+
def rename_table(self, dialect: DialectType = None) -> "RenameTable":
406+
"""Create a RENAME TABLE builder.
407+
408+
Args:
409+
dialect: Optional SQL dialect
410+
411+
Returns:
412+
RenameTable builder instance
413+
"""
414+
builder = RenameTable()
415+
builder.dialect = dialect or self.dialect
416+
return builder
417+
418+
def comment_on(self, dialect: DialectType = None) -> "CommentOn":
419+
"""Create a COMMENT ON builder.
420+
421+
Args:
422+
dialect: Optional SQL dialect
423+
424+
Returns:
425+
CommentOn builder instance
426+
"""
427+
builder = CommentOn()
428+
builder.dialect = dialect or self.dialect
429+
return builder
430+
215431
# ===================
216432
# SQL Analysis Helpers
217433
# ===================
@@ -363,39 +579,39 @@ def __getattr__(self, name: str) -> Column:
363579
# ===================
364580

365581
@staticmethod
366-
def raw(sql_fragment: str) -> exp.Expression:
367-
"""Create a raw SQL expression from a string fragment.
582+
def raw(sql_fragment: str, **parameters: Any) -> "Union[exp.Expression, SQL]":
583+
"""Create a raw SQL expression from a string fragment with optional parameters.
368584
369585
This method makes it explicit that you are passing raw SQL that should
370586
be parsed and included directly in the query. Useful for complex expressions,
371587
database-specific functions, or when you need precise control over the SQL.
372588
373589
Args:
374590
sql_fragment: Raw SQL string to parse into an expression.
591+
**parameters: Named parameters for parameter binding.
375592
376593
Returns:
377-
SQLGlot expression from the parsed SQL fragment.
594+
SQLGlot expression from the parsed SQL fragment (if no parameters).
595+
SQL statement object (if parameters provided).
378596
379597
Raises:
380598
SQLBuilderError: If the SQL fragment cannot be parsed.
381599
382600
Example:
383601
```python
384-
# Raw column expression with alias
385-
query = sql.select(
386-
sql.raw("user.id AS u_id"), "name"
387-
).from_("users")
602+
# Raw expression without parameters (current behavior)
603+
expr = sql.raw("COALESCE(name, 'Unknown')")
388604
389-
# Raw function call
390-
query = sql.select(
391-
sql.raw("COALESCE(name, 'Unknown')")
392-
).from_("users")
605+
# Raw SQL with named parameters (new functionality)
606+
stmt = sql.raw(
607+
"LOWER(name) LIKE LOWER(:pattern)", pattern=f"%{query}%"
608+
)
393609
394-
# Raw complex expression
395-
query = (
396-
sql.select("*")
397-
.from_("orders")
398-
.where(sql.raw("DATE(created_at) = CURRENT_DATE"))
610+
# Raw complex expression with parameters
611+
expr = sql.raw(
612+
"price BETWEEN :min_price AND :max_price",
613+
min_price=100,
614+
max_price=500,
399615
)
400616
401617
# Raw window function
@@ -407,16 +623,23 @@ def raw(sql_fragment: str) -> exp.Expression:
407623
).from_("employees")
408624
```
409625
"""
410-
try:
411-
parsed: Optional[exp.Expression] = exp.maybe_parse(sql_fragment)
412-
if parsed is not None:
413-
return parsed
414-
if sql_fragment.strip().replace("_", "").replace(".", "").isalnum():
415-
return exp.to_identifier(sql_fragment)
416-
return exp.Literal.string(sql_fragment)
417-
except Exception as e:
418-
msg = f"Failed to parse raw SQL fragment '{sql_fragment}': {e}"
419-
raise SQLBuilderError(msg) from e
626+
if not parameters:
627+
# Original behavior - return pure expression
628+
try:
629+
parsed: Optional[exp.Expression] = exp.maybe_parse(sql_fragment)
630+
if parsed is not None:
631+
return parsed
632+
if sql_fragment.strip().replace("_", "").replace(".", "").isalnum():
633+
return exp.to_identifier(sql_fragment)
634+
return exp.Literal.string(sql_fragment)
635+
except Exception as e:
636+
msg = f"Failed to parse raw SQL fragment '{sql_fragment}': {e}"
637+
raise SQLBuilderError(msg) from e
638+
639+
# New behavior - return SQL statement with parameters
640+
from sqlspec.core.statement import SQL
641+
642+
return SQL(sql_fragment, parameters)
420643

421644
# ===================
422645
# Aggregate Functions

sqlspec/builder/_insert.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,18 @@ def values(self, *values: Any) -> "Self":
119119
msg = ERR_MSG_VALUES_COLUMNS_MISMATCH.format(values_len=len(values), columns_len=len(self._columns))
120120
raise SQLBuilderError(msg)
121121

122-
param_names = [self._add_parameter(value) for value in values]
122+
param_names = []
123+
for i, value in enumerate(values):
124+
# Try to use column name if available, otherwise use position-based name
125+
if self._columns and i < len(self._columns):
126+
column_name = (
127+
str(self._columns[i]).split(".")[-1] if "." in str(self._columns[i]) else str(self._columns[i])
128+
)
129+
param_name = self._generate_unique_parameter_name(column_name)
130+
else:
131+
param_name = self._generate_unique_parameter_name(f"value_{i + 1}")
132+
_, param_name = self.add_parameter(value, name=param_name)
133+
param_names.append(param_name)
123134
value_placeholders = tuple(exp.var(name) for name in param_names)
124135

125136
current_values_expression = insert_expr.args.get("expression")

sqlspec/builder/_parsing_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,11 @@ def parse_condition_expression(
109109
if value is None:
110110
return exp.Is(this=column_expr, expression=exp.null())
111111
if builder and has_parameter_builder(builder):
112-
_, param_name = builder.add_parameter(value)
112+
from sqlspec.builder.mixins._where_clause import _extract_column_name
113+
114+
column_name = _extract_column_name(column)
115+
param_name = builder._generate_unique_parameter_name(column_name)
116+
_, param_name = builder.add_parameter(value, name=param_name)
113117
return exp.EQ(this=column_expr, expression=exp.Placeholder(this=param_name))
114118
if isinstance(value, str):
115119
return exp.EQ(this=column_expr, expression=exp.convert(value))

sqlspec/builder/mixins/_insert_operations.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,21 @@ def values(self, *values: Any) -> Self:
7878
except AttributeError:
7979
pass
8080
row_exprs = []
81-
for v in values:
81+
for i, v in enumerate(values):
8282
if isinstance(v, exp.Expression):
8383
row_exprs.append(v)
8484
else:
85-
_, param_name = self.add_parameter(v) # type: ignore[attr-defined]
85+
# Try to use column name if available, otherwise use position-based name
86+
try:
87+
_columns = self._columns # type: ignore[attr-defined]
88+
if _columns and i < len(_columns):
89+
column_name = str(_columns[i]).split(".")[-1] if "." in str(_columns[i]) else str(_columns[i])
90+
param_name = self._generate_unique_parameter_name(column_name) # type: ignore[attr-defined]
91+
else:
92+
param_name = self._generate_unique_parameter_name(f"value_{i + 1}") # type: ignore[attr-defined]
93+
except AttributeError:
94+
param_name = self._generate_unique_parameter_name(f"value_{i + 1}") # type: ignore[attr-defined]
95+
_, param_name = self.add_parameter(v, name=param_name) # type: ignore[attr-defined]
8696
row_exprs.append(exp.var(param_name))
8797
values_expr = exp.Values(expressions=[row_exprs])
8898
self._expression.set("expression", values_expr)

0 commit comments

Comments
 (0)