13
13
from sqlglot .optimizer import optimize
14
14
from typing_extensions import Self
15
15
16
- from sqlspec .core .cache import CacheKey , get_cache_config , get_default_cache
16
+ from sqlspec .core .cache import get_cache , get_cache_config
17
17
from sqlspec .core .hashing import hash_optimized_expression
18
18
from sqlspec .core .parameters import ParameterStyle , ParameterStyleConfig
19
19
from sqlspec .core .statement import SQL , StatementConfig
@@ -91,6 +91,36 @@ def _initialize_expression(self) -> None:
91
91
"QueryBuilder._create_base_expression must return a valid sqlglot expression."
92
92
)
93
93
94
+ def get_expression (self ) -> Optional [exp .Expression ]:
95
+ """Get expression reference (no copy).
96
+
97
+ Returns:
98
+ The current SQLGlot expression or None if not set
99
+ """
100
+ return self ._expression
101
+
102
+ def set_expression (self , expression : exp .Expression ) -> None :
103
+ """Set expression with validation.
104
+
105
+ Args:
106
+ expression: SQLGlot expression to set
107
+
108
+ Raises:
109
+ TypeError: If expression is not a SQLGlot Expression
110
+ """
111
+ if not isinstance (expression , exp .Expression ):
112
+ msg = f"Expected Expression, got { type (expression )} "
113
+ raise TypeError (msg )
114
+ self ._expression = expression
115
+
116
+ def has_expression (self ) -> bool :
117
+ """Check if expression exists.
118
+
119
+ Returns:
120
+ True if expression is set, False otherwise
121
+ """
122
+ return self ._expression is not None
123
+
94
124
@abstractmethod
95
125
def _create_base_expression (self ) -> exp .Expression :
96
126
"""Create the base sqlglot expression for the specific query type.
@@ -307,12 +337,13 @@ def with_cte(self: Self, alias: str, query: "Union[QueryBuilder, exp.Select, str
307
337
cte_select_expression : exp .Select
308
338
309
339
if isinstance (query , QueryBuilder ):
310
- if query ._expression is None :
340
+ query_expr = query .get_expression ()
341
+ if query_expr is None :
311
342
self ._raise_sql_builder_error ("CTE query builder has no expression." )
312
- if not isinstance (query . _expression , exp .Select ):
313
- msg = f"CTE query builder expression must be a Select, got { type (query . _expression ).__name__ } ."
343
+ if not isinstance (query_expr , exp .Select ):
344
+ msg = f"CTE query builder expression must be a Select, got { type (query_expr ).__name__ } ."
314
345
self ._raise_sql_builder_error (msg )
315
- cte_select_expression = query . _expression
346
+ cte_select_expression = query_expr
316
347
param_mapping = self ._merge_cte_parameters (alias , query .parameters )
317
348
updated_expression = self ._update_placeholders_in_expression (cte_select_expression , param_mapping )
318
349
if not isinstance (updated_expression , exp .Select ):
@@ -398,9 +429,8 @@ def _optimize_expression(self, expression: exp.Expression) -> exp.Expression:
398
429
expression , dialect = dialect_name , schema = self .schema , optimizer_settings = optimizer_settings
399
430
)
400
431
401
- cache_key_obj = CacheKey ((cache_key ,))
402
- unified_cache = get_default_cache ()
403
- cached_optimized = unified_cache .get (cache_key_obj )
432
+ cache = get_cache ()
433
+ cached_optimized = cache .get ("optimized" , cache_key )
404
434
if cached_optimized :
405
435
return cast ("exp.Expression" , cached_optimized )
406
436
@@ -409,7 +439,7 @@ def _optimize_expression(self, expression: exp.Expression) -> exp.Expression:
409
439
expression , schema = self .schema , dialect = self .dialect_name , optimizer_settings = optimizer_settings
410
440
)
411
441
412
- unified_cache .put (cache_key_obj , optimized )
442
+ cache .put ("optimized" , cache_key , optimized )
413
443
414
444
except Exception :
415
445
return expression
@@ -430,15 +460,14 @@ def to_statement(self, config: "Optional[StatementConfig]" = None) -> "SQL":
430
460
return self ._to_statement (config )
431
461
432
462
cache_key_str = self ._generate_builder_cache_key (config )
433
- cache_key = CacheKey ((cache_key_str ,))
434
463
435
- unified_cache = get_default_cache ()
436
- cached_sql = unified_cache .get (cache_key )
464
+ cache = get_cache ()
465
+ cached_sql = cache .get ("builder" , cache_key_str )
437
466
if cached_sql is not None :
438
467
return cast ("SQL" , cached_sql )
439
468
440
469
sql_statement = self ._to_statement (config )
441
- unified_cache .put (cache_key , sql_statement )
470
+ cache .put ("builder" , cache_key_str , sql_statement )
442
471
443
472
return sql_statement
444
473
@@ -531,3 +560,16 @@ def _merge_sql_object_parameters(self, sql_obj: Any) -> None:
531
560
def parameters (self ) -> dict [str , Any ]:
532
561
"""Public access to query parameters."""
533
562
return self ._parameters
563
+
564
+ def set_parameters (self , parameters : dict [str , Any ]) -> None :
565
+ """Set query parameters (public API)."""
566
+ self ._parameters = parameters .copy ()
567
+
568
+ @property
569
+ def with_ctes (self ) -> "dict[str, exp.CTE]" :
570
+ """Get WITH clause CTEs (public API)."""
571
+ return dict (self ._with_ctes )
572
+
573
+ def generate_unique_parameter_name (self , base_name : str ) -> str :
574
+ """Generate unique parameter name (public API)."""
575
+ return self ._generate_unique_parameter_name (base_name )
0 commit comments