Skip to content

Commit f49a575

Browse files
authored
feat: remove private variable usage and improve public API (#76)
Implement improvements in statement caching and update private usage reporting settings. Fix remaining private calls to ensure compliance with type checking.
1 parent dc8e5e9 commit f49a575

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+1418
-1357
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,8 @@ exclude = ["**/node_modules", "**/__pycache__", ".venv", "tools", "docs", "tmp",
360360
include = ["sqlspec", "tests"]
361361
pythonVersion = "3.9"
362362
reportMissingTypeStubs = false
363-
reportPrivateImportUsage = false
364-
reportPrivateUsage = false
363+
reportPrivateImportUsage = true
364+
reportPrivateUsage = true
365365
reportTypedDictNotRequiredAccess = false
366366
reportUnknownArgumentType = false
367367
reportUnnecessaryCast = false

sqlspec/_sql.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def __call__(self, statement: str, dialect: DialectType = None) -> "Any":
170170
actual_type_str == "WITH" and parsed_expr.this and isinstance(parsed_expr.this, exp.Select)
171171
):
172172
builder = Select(dialect=dialect or self.dialect)
173-
builder._expression = parsed_expr
173+
builder.set_expression(parsed_expr)
174174
return builder
175175

176176
if actual_type_str in {"INSERT", "UPDATE", "DELETE"} and parsed_expr.args.get("returning") is not None:
@@ -451,7 +451,7 @@ def _populate_insert_from_sql(self, builder: "Insert", sql_string: str) -> "Inse
451451
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)
452452

453453
if isinstance(parsed_expr, exp.Insert):
454-
builder._expression = parsed_expr
454+
builder.set_expression(parsed_expr)
455455
return builder
456456

457457
if isinstance(parsed_expr, exp.Select):
@@ -470,7 +470,7 @@ def _populate_select_from_sql(self, builder: "Select", sql_string: str) -> "Sele
470470
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)
471471

472472
if isinstance(parsed_expr, exp.Select):
473-
builder._expression = parsed_expr
473+
builder.set_expression(parsed_expr)
474474
return builder
475475

476476
logger.warning("Cannot create SELECT from %s statement", type(parsed_expr).__name__)
@@ -485,7 +485,7 @@ def _populate_update_from_sql(self, builder: "Update", sql_string: str) -> "Upda
485485
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)
486486

487487
if isinstance(parsed_expr, exp.Update):
488-
builder._expression = parsed_expr
488+
builder.set_expression(parsed_expr)
489489
return builder
490490

491491
logger.warning("Cannot create UPDATE from %s statement", type(parsed_expr).__name__)
@@ -500,7 +500,7 @@ def _populate_delete_from_sql(self, builder: "Delete", sql_string: str) -> "Dele
500500
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)
501501

502502
if isinstance(parsed_expr, exp.Delete):
503-
builder._expression = parsed_expr
503+
builder.set_expression(parsed_expr)
504504
return builder
505505

506506
logger.warning("Cannot create DELETE from %s statement", type(parsed_expr).__name__)
@@ -515,7 +515,7 @@ def _populate_merge_from_sql(self, builder: "Merge", sql_string: str) -> "Merge"
515515
parsed_expr: exp.Expression = exp.maybe_parse(sql_string, dialect=self.dialect)
516516

517517
if isinstance(parsed_expr, exp.Merge):
518-
builder._expression = parsed_expr
518+
builder.set_expression(parsed_expr)
519519
return builder
520520

521521
logger.warning("Cannot create MERGE from %s statement", type(parsed_expr).__name__)
@@ -724,19 +724,15 @@ def raw(sql_fragment: str, **parameters: Any) -> "Union[exp.Expression, SQL]":
724724
if not parameters:
725725
try:
726726
parsed: exp.Expression = exp.maybe_parse(sql_fragment)
727-
return parsed
728-
if sql_fragment.strip().replace("_", "").replace(".", "").isalnum():
729-
return exp.to_identifier(sql_fragment)
730-
return exp.Literal.string(sql_fragment)
731727
except Exception as e:
732728
msg = f"Failed to parse raw SQL fragment '{sql_fragment}': {e}"
733729
raise SQLBuilderError(msg) from e
730+
return parsed
734731

735732
return SQL(sql_fragment, parameters)
736733

737-
@staticmethod
738734
def count(
739-
column: Union[str, exp.Expression, "ExpressionWrapper", "Case", "Column"] = "*", distinct: bool = False
735+
self, column: Union[str, exp.Expression, "ExpressionWrapper", "Case", "Column"] = "*", distinct: bool = False
740736
) -> AggregateExpression:
741737
"""Create a COUNT expression.
742738
@@ -750,7 +746,7 @@ def count(
750746
if isinstance(column, str) and column == "*":
751747
expr = exp.Count(this=exp.Star(), distinct=distinct)
752748
else:
753-
col_expr = SQLFactory._extract_expression(column)
749+
col_expr = self._extract_expression(column)
754750
expr = exp.Count(this=col_expr, distinct=distinct)
755751
return AggregateExpression(expr)
756752

@@ -1068,11 +1064,11 @@ def _extract_expression(value: Any) -> exp.Expression:
10681064
if isinstance(value, str):
10691065
return exp.column(value)
10701066
if isinstance(value, Column):
1071-
return value._expression
1067+
return value.sqlglot_expression
10721068
if isinstance(value, ExpressionWrapper):
10731069
return value.expression
10741070
if isinstance(value, Case):
1075-
return exp.Case(ifs=value._conditions, default=value._default)
1071+
return exp.Case(ifs=value.conditions, default=value.default)
10761072
if isinstance(value, exp.Expression):
10771073
return value
10781074
return exp.convert(value)

sqlspec/adapters/adbc/driver.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def _execute_many(self, cursor: "Cursor", statement: SQL) -> "ExecutionResult":
521521

522522
try:
523523
if not prepared_parameters:
524-
cursor._rowcount = 0
524+
cursor._rowcount = 0 # pyright: ignore[reportPrivateUsage]
525525
row_count = 0
526526
elif isinstance(prepared_parameters, list) and prepared_parameters:
527527
processed_params = []
@@ -596,7 +596,7 @@ def _execute_script(self, cursor: "Cursor", statement: "SQL") -> "ExecutionResul
596596
Execution result with statement counts
597597
"""
598598
if statement.is_script:
599-
sql = statement._raw_sql
599+
sql = statement.raw_sql
600600
prepared_parameters: list[Any] = []
601601
else:
602602
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)

sqlspec/adapters/oracledb/driver.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,11 @@ def _execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult":
286286
msg = "execute_many requires parameters"
287287
raise ValueError(msg)
288288

289+
# Oracle-specific fix: Ensure parameters are in list format for executemany
290+
# Oracle expects a list of sequences, not a tuple of sequences
291+
if isinstance(prepared_parameters, tuple):
292+
prepared_parameters = list(prepared_parameters)
293+
289294
cursor.executemany(sql, prepared_parameters)
290295

291296
# Calculate affected rows based on parameter count

sqlspec/adapters/psycopg/config.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,7 @@ def _close_pool(self) -> None:
173173
logger.info("Closing Psycopg connection pool", extra={"adapter": "psycopg"})
174174

175175
try:
176-
if hasattr(self.pool_instance, "_closed"):
177-
self.pool_instance._closed = True
176+
self.pool_instance._closed = True # pyright: ignore[reportPrivateUsage]
178177

179178
self.pool_instance.close()
180179
logger.info("Psycopg connection pool closed successfully", extra={"adapter": "psycopg"})
@@ -350,8 +349,7 @@ async def _close_pool(self) -> None:
350349
return
351350

352351
try:
353-
if hasattr(self.pool_instance, "_closed"):
354-
self.pool_instance._closed = True
352+
self.pool_instance._closed = True # pyright: ignore[reportPrivateUsage]
355353

356354
await self.pool_instance.close()
357355
finally:

sqlspec/base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
)
1616
from sqlspec.core.cache import (
1717
CacheConfig,
18-
CacheStatsAggregate,
1918
get_cache_config,
20-
get_cache_stats,
19+
get_cache_statistics,
2120
log_cache_stats,
2221
reset_cache_stats,
2322
update_cache_config,
@@ -532,13 +531,13 @@ def update_cache_config(config: CacheConfig) -> None:
532531
update_cache_config(config)
533532

534533
@staticmethod
535-
def get_cache_stats() -> CacheStatsAggregate:
534+
def get_cache_stats() -> "dict[str, Any]":
536535
"""Get current cache statistics.
537536
538537
Returns:
539538
Cache statistics object with detailed metrics.
540539
"""
541-
return get_cache_stats()
540+
return get_cache_statistics()
542541

543542
@staticmethod
544543
def reset_cache_stats() -> None:

sqlspec/builder/_base.py

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sqlglot.optimizer import optimize
1414
from typing_extensions import Self
1515

16-
from sqlspec.core.cache import CacheKey, get_cache_config, get_default_cache
16+
from sqlspec.core.cache import get_cache, get_cache_config
1717
from sqlspec.core.hashing import hash_optimized_expression
1818
from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
1919
from sqlspec.core.statement import SQL, StatementConfig
@@ -91,6 +91,36 @@ def _initialize_expression(self) -> None:
9191
"QueryBuilder._create_base_expression must return a valid sqlglot expression."
9292
)
9393

94+
def get_expression(self) -> Optional[exp.Expression]:
95+
"""Get expression reference (no copy).
96+
97+
Returns:
98+
The current SQLGlot expression or None if not set
99+
"""
100+
return self._expression
101+
102+
def set_expression(self, expression: exp.Expression) -> None:
103+
"""Set expression with validation.
104+
105+
Args:
106+
expression: SQLGlot expression to set
107+
108+
Raises:
109+
TypeError: If expression is not a SQLGlot Expression
110+
"""
111+
if not isinstance(expression, exp.Expression):
112+
msg = f"Expected Expression, got {type(expression)}"
113+
raise TypeError(msg)
114+
self._expression = expression
115+
116+
def has_expression(self) -> bool:
117+
"""Check if expression exists.
118+
119+
Returns:
120+
True if expression is set, False otherwise
121+
"""
122+
return self._expression is not None
123+
94124
@abstractmethod
95125
def _create_base_expression(self) -> exp.Expression:
96126
"""Create the base sqlglot expression for the specific query type.
@@ -307,12 +337,13 @@ def with_cte(self: Self, alias: str, query: "Union[QueryBuilder, exp.Select, str
307337
cte_select_expression: exp.Select
308338

309339
if isinstance(query, QueryBuilder):
310-
if query._expression is None:
340+
query_expr = query.get_expression()
341+
if query_expr is None:
311342
self._raise_sql_builder_error("CTE query builder has no expression.")
312-
if not isinstance(query._expression, exp.Select):
313-
msg = f"CTE query builder expression must be a Select, got {type(query._expression).__name__}."
343+
if not isinstance(query_expr, exp.Select):
344+
msg = f"CTE query builder expression must be a Select, got {type(query_expr).__name__}."
314345
self._raise_sql_builder_error(msg)
315-
cte_select_expression = query._expression
346+
cte_select_expression = query_expr
316347
param_mapping = self._merge_cte_parameters(alias, query.parameters)
317348
updated_expression = self._update_placeholders_in_expression(cte_select_expression, param_mapping)
318349
if not isinstance(updated_expression, exp.Select):
@@ -398,9 +429,8 @@ def _optimize_expression(self, expression: exp.Expression) -> exp.Expression:
398429
expression, dialect=dialect_name, schema=self.schema, optimizer_settings=optimizer_settings
399430
)
400431

401-
cache_key_obj = CacheKey((cache_key,))
402-
unified_cache = get_default_cache()
403-
cached_optimized = unified_cache.get(cache_key_obj)
432+
cache = get_cache()
433+
cached_optimized = cache.get("optimized", cache_key)
404434
if cached_optimized:
405435
return cast("exp.Expression", cached_optimized)
406436

@@ -409,7 +439,7 @@ def _optimize_expression(self, expression: exp.Expression) -> exp.Expression:
409439
expression, schema=self.schema, dialect=self.dialect_name, optimizer_settings=optimizer_settings
410440
)
411441

412-
unified_cache.put(cache_key_obj, optimized)
442+
cache.put("optimized", cache_key, optimized)
413443

414444
except Exception:
415445
return expression
@@ -430,15 +460,14 @@ def to_statement(self, config: "Optional[StatementConfig]" = None) -> "SQL":
430460
return self._to_statement(config)
431461

432462
cache_key_str = self._generate_builder_cache_key(config)
433-
cache_key = CacheKey((cache_key_str,))
434463

435-
unified_cache = get_default_cache()
436-
cached_sql = unified_cache.get(cache_key)
464+
cache = get_cache()
465+
cached_sql = cache.get("builder", cache_key_str)
437466
if cached_sql is not None:
438467
return cast("SQL", cached_sql)
439468

440469
sql_statement = self._to_statement(config)
441-
unified_cache.put(cache_key, sql_statement)
470+
cache.put("builder", cache_key_str, sql_statement)
442471

443472
return sql_statement
444473

@@ -531,3 +560,16 @@ def _merge_sql_object_parameters(self, sql_obj: Any) -> None:
531560
def parameters(self) -> dict[str, Any]:
532561
"""Public access to query parameters."""
533562
return self._parameters
563+
564+
def set_parameters(self, parameters: dict[str, Any]) -> None:
565+
"""Set query parameters (public API)."""
566+
self._parameters = parameters.copy()
567+
568+
@property
569+
def with_ctes(self) -> "dict[str, exp.CTE]":
570+
"""Get WITH clause CTEs (public API)."""
571+
return dict(self._with_ctes)
572+
573+
def generate_unique_parameter_name(self, base_name: str) -> str:
574+
"""Generate unique parameter name (public API)."""
575+
return self._generate_unique_parameter_name(base_name)

sqlspec/builder/_column.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,15 @@ def __hash__(self) -> int:
254254
"""Hash based on table and column name."""
255255
return hash((self.table, self.name))
256256

257+
@property
258+
def sqlglot_expression(self) -> exp.Expression:
259+
"""Get the underlying SQLGlot expression (public API).
260+
261+
Returns:
262+
The SQLGlot expression for this column
263+
"""
264+
return self._expression
265+
257266

258267
class FunctionColumn:
259268
"""Represents the result of a SQL function call on a column."""

sqlspec/builder/_ddl.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -973,10 +973,10 @@ def _create_base_expression(self) -> exp.Expression:
973973
select_expr = self._select_query.expression
974974
select_parameters = self._select_query.parameters
975975
elif isinstance(self._select_query, Select):
976-
select_expr = self._select_query._expression
977-
select_parameters = self._select_query._parameters
976+
select_expr = self._select_query.get_expression()
977+
select_parameters = self._select_query.parameters
978978

979-
with_ctes = self._select_query._with_ctes
979+
with_ctes = self._select_query.with_ctes
980980
if with_ctes and select_expr and isinstance(select_expr, exp.Select):
981981
for alias, cte in with_ctes.items():
982982
if has_with_method(select_expr):
@@ -1100,8 +1100,8 @@ def _create_base_expression(self) -> exp.Expression:
11001100
select_expr = self._select_query.expression
11011101
select_parameters = self._select_query.parameters
11021102
elif isinstance(self._select_query, Select):
1103-
select_expr = self._select_query._expression
1104-
select_parameters = self._select_query._parameters
1103+
select_expr = self._select_query.get_expression()
1104+
select_parameters = self._select_query.parameters
11051105
elif isinstance(self._select_query, str):
11061106
select_expr = exp.maybe_parse(self._select_query)
11071107
select_parameters = None
@@ -1198,8 +1198,8 @@ def _create_base_expression(self) -> exp.Expression:
11981198
select_expr = self._select_query.expression
11991199
select_parameters = self._select_query.parameters
12001200
elif isinstance(self._select_query, Select):
1201-
select_expr = self._select_query._expression
1202-
select_parameters = self._select_query._parameters
1201+
select_expr = self._select_query.get_expression()
1202+
select_parameters = self._select_query.parameters
12031203
elif isinstance(self._select_query, str):
12041204
select_expr = exp.maybe_parse(self._select_query)
12051205
select_parameters = None

0 commit comments

Comments
 (0)