Skip to content

Commit cba08ac

Browse files
run primitive parameterised query tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent a61d6f3 commit cba08ac

File tree

2 files changed

+184
-27
lines changed

2 files changed

+184
-27
lines changed

src/databricks/sql/backend/sea/backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,9 @@ def execute_command(
439439
sea_parameters.append(
440440
StatementParameter(
441441
name=param.name,
442-
value=param.value,
442+
value=(
443+
None if param.value is None else param.value.stringValue
444+
),
443445
type=param.type,
444446
)
445447
)

tests/e2e/test_parameterized_queries.py

Lines changed: 181 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,13 @@ def patch_server_supports_native_params(self, supports_native_params: bool = Tru
168168
finally:
169169
pass
170170

171-
def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column):
171+
def _inline_roundtrip(
172+
self,
173+
params: dict,
174+
paramstyle: ParamStyle,
175+
target_column,
176+
extra_params: dict = {},
177+
):
172178
"""This INSERT, SELECT, DELETE dance is necessary because simply selecting
173179
```
174180
"SELECT %(param)s"
@@ -183,7 +189,9 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column)
183189
SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1"
184190
DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table"
185191

186-
with self.connection(extra_params={"use_inline_params": True}) as conn:
192+
with self.connection(
193+
extra_params={"use_inline_params": True, **extra_params}
194+
) as conn:
187195
with conn.cursor() as cursor:
188196
cursor.execute(INSERT_QUERY, parameters=params)
189197
with conn.cursor() as cursor:
@@ -198,14 +206,17 @@ def _native_roundtrip(
198206
parameters: Union[Dict, List[Dict]],
199207
paramstyle: ParamStyle,
200208
parameter_structure: ParameterStructure,
209+
extra_params: dict = {},
201210
):
202211
if parameter_structure == ParameterStructure.POSITIONAL:
203212
_query = self.POSITIONAL_PARAMSTYLE_QUERY
204213
elif paramstyle == ParamStyle.NAMED:
205214
_query = self.NAMED_PARAMSTYLE_QUERY
206215
elif paramstyle == ParamStyle.PYFORMAT:
207216
_query = self.PYFORMAT_PARAMSTYLE_QUERY
208-
with self.connection(extra_params={"use_inline_params": False}) as conn:
217+
with self.connection(
218+
extra_params={"use_inline_params": False, **extra_params}
219+
) as conn:
209220
with conn.cursor() as cursor:
210221
cursor.execute(_query, parameters=parameters)
211222
return cursor.fetchone()
@@ -216,6 +227,7 @@ def _get_one_result(
216227
approach: ParameterApproach = ParameterApproach.NONE,
217228
paramstyle: ParamStyle = ParamStyle.NONE,
218229
parameter_structure: ParameterStructure = ParameterStructure.NONE,
230+
extra_params: dict = {},
219231
):
220232
"""When approach is INLINE then we use %(param)s paramstyle and a connection with use_inline_params=True
221233
When approach is NATIVE then we use :param paramstyle and a connection with use_inline_params=False
@@ -228,12 +240,16 @@ def _get_one_result(
228240
params,
229241
paramstyle=ParamStyle.PYFORMAT,
230242
target_column=self._get_inline_table_column(params.get("p")),
243+
extra_params=extra_params,
231244
)
232245
elif approach == ParameterApproach.NATIVE:
233246
# native mode can use either ParamStyle.NAMED or ParamStyle.PYFORMAT
234247
# native mode can use either ParameterStructure.NAMED or ParameterStructure.POSITIONAL
235248
return self._native_roundtrip(
236-
params, paramstyle=paramstyle, parameter_structure=parameter_structure
249+
params,
250+
paramstyle=paramstyle,
251+
parameter_structure=parameter_structure,
252+
extra_params=extra_params,
237253
)
238254

239255
def _quantize(self, input: Union[float, int], place_value=2) -> Decimal:
@@ -379,7 +395,20 @@ def test_dbsqlparameter_single(
379395
assert self._eq(result.col, primitive)
380396

381397
@pytest.mark.parametrize("use_inline_params", (True, False, "silent"))
382-
def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog):
398+
@pytest.mark.parametrize(
399+
"extra_params",
400+
[
401+
{},
402+
{
403+
"use_sea": True,
404+
"use_cloud_fetch": False,
405+
"enable_query_result_lz4_compression": False,
406+
},
407+
],
408+
)
409+
def test_use_inline_off_by_default_with_warning(
410+
self, use_inline_params, caplog, extra_params
411+
):
383412
"""
384413
use_inline_params should be False by default.
385414
If a user explicitly sets use_inline_params, don't warn them about it.
@@ -389,7 +418,7 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog)
389418
{"use_inline_params": use_inline_params} if use_inline_params else {}
390419
)
391420

392-
with self.connection(extra_params=extra_args) as conn:
421+
with self.connection(extra_params={**extra_args, **extra_params}) as conn:
393422
with conn.cursor() as cursor:
394423
with self.patch_server_supports_native_params(
395424
supports_native_params=True
@@ -404,9 +433,20 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog)
404433
"Consider using native parameters." not in caplog.text
405434
), "Log message should not be supressed"
406435

407-
def test_positional_native_params_with_defaults(self):
436+
@pytest.mark.parametrize(
437+
"extra_params",
438+
[
439+
{},
440+
{
441+
"use_sea": True,
442+
"use_cloud_fetch": False,
443+
"enable_query_result_lz4_compression": False,
444+
},
445+
],
446+
)
447+
def test_positional_native_params_with_defaults(self, extra_params):
408448
query = "SELECT ? col"
409-
with self.cursor() as cursor:
449+
with self.cursor(extra_params=extra_params) as cursor:
410450
result = cursor.execute(query, parameters=[1]).fetchone()
411451

412452
assert result.col == 1
@@ -422,19 +462,43 @@ def test_positional_native_params_with_defaults(self):
422462
["foo", "bar", "baz"],
423463
),
424464
)
425-
def test_positional_native_multiple(self, params):
465+
@pytest.mark.parametrize(
466+
"extra_params",
467+
[
468+
{},
469+
{
470+
"use_sea": True,
471+
"use_cloud_fetch": False,
472+
"enable_query_result_lz4_compression": False,
473+
},
474+
],
475+
)
476+
def test_positional_native_multiple(self, params, extra_params):
426477
query = "SELECT ? `foo`, ? `bar`, ? `baz`"
427478

428-
with self.cursor(extra_params={"use_inline_params": False}) as cursor:
479+
with self.cursor(
480+
extra_params={"use_inline_params": False, **extra_params}
481+
) as cursor:
429482
result = cursor.execute(query, params).fetchone()
430483

431484
expected = [i.value if isinstance(i, DbsqlParameterBase) else i for i in params]
432485
outcome = [result.foo, result.bar, result.baz]
433486

434487
assert set(outcome) == set(expected)
435488

436-
def test_readme_example(self):
437-
with self.cursor() as cursor:
489+
@pytest.mark.parametrize(
490+
"extra_params",
491+
[
492+
{},
493+
{
494+
"use_sea": True,
495+
"use_cloud_fetch": False,
496+
"enable_query_result_lz4_compression": False,
497+
},
498+
],
499+
)
500+
def test_readme_example(self, extra_params):
501+
with self.cursor(extra_params=extra_params) as cursor:
438502
result = cursor.execute(
439503
"SELECT :param `p`, * FROM RANGE(10)", {"param": "foo"}
440504
).fetchall()
@@ -498,19 +562,43 @@ def test_native_recursive_complex_type(
498562
class TestInlineParameterSyntax(PySQLPytestTestCase):
499563
"""The inline parameter approach uses pyformat markers"""
500564

501-
def test_params_as_dict(self):
565+
@pytest.mark.parametrize(
566+
"extra_params",
567+
[
568+
{},
569+
{
570+
"use_sea": True,
571+
"use_cloud_fetch": False,
572+
"enable_query_result_lz4_compression": False,
573+
},
574+
],
575+
)
576+
def test_params_as_dict(self, extra_params):
502577
query = "SELECT %(foo)s foo, %(bar)s bar, %(baz)s baz"
503578
params = {"foo": 1, "bar": 2, "baz": 3}
504579

505-
with self.connection(extra_params={"use_inline_params": True}) as conn:
580+
with self.connection(
581+
extra_params={"use_inline_params": True, **extra_params}
582+
) as conn:
506583
with conn.cursor() as cursor:
507584
result = cursor.execute(query, parameters=params).fetchone()
508585

509586
assert result.foo == 1
510587
assert result.bar == 2
511588
assert result.baz == 3
512589

513-
def test_params_as_sequence(self):
590+
@pytest.mark.parametrize(
591+
"extra_params",
592+
[
593+
{},
594+
{
595+
"use_sea": True,
596+
"use_cloud_fetch": False,
597+
"enable_query_result_lz4_compression": False,
598+
},
599+
],
600+
)
601+
def test_params_as_sequence(self, extra_params):
514602
"""One side-effect of ParamEscaper using Python string interpolation to inline the values
515603
is that it can work with "ordinal" parameters, but only if a user writes parameter markers
516604
that are not defined with PEP-249. This test exists to prove that it works in the ideal case.
@@ -520,27 +608,53 @@ def test_params_as_sequence(self):
520608
query = "SELECT %s foo, %s bar, %s baz"
521609
params = (1, 2, 3)
522610

523-
with self.connection(extra_params={"use_inline_params": True}) as conn:
611+
with self.connection(
612+
extra_params={"use_inline_params": True, **extra_params}
613+
) as conn:
524614
with conn.cursor() as cursor:
525615
result = cursor.execute(query, parameters=params).fetchone()
526616
assert result.foo == 1
527617
assert result.bar == 2
528618
assert result.baz == 3
529619

530-
def test_inline_ordinals_can_break_sql(self):
620+
@pytest.mark.parametrize(
621+
"extra_params",
622+
[
623+
{},
624+
{
625+
"use_sea": True,
626+
"use_cloud_fetch": False,
627+
"enable_query_result_lz4_compression": False,
628+
},
629+
],
630+
)
631+
def test_inline_ordinals_can_break_sql(self, extra_params):
531632
"""With inline mode, ordinal parameters can break the SQL syntax
532633
because `%` symbols are used to wildcard match within LIKE statements. This test
533634
just proves that's the case.
534635
"""
535636
query = "SELECT 'samsonite', %s WHERE 'samsonite' LIKE '%sonite'"
536637
params = ["luggage"]
537-
with self.cursor(extra_params={"use_inline_params": True}) as cursor:
638+
with self.cursor(
639+
extra_params={"use_inline_params": True, **extra_params}
640+
) as cursor:
538641
with pytest.raises(
539642
TypeError, match="not enough arguments for format string"
540643
):
541644
cursor.execute(query, parameters=params)
542645

543-
def test_inline_named_dont_break_sql(self):
646+
@pytest.mark.parametrize(
647+
"extra_params",
648+
[
649+
{},
650+
{
651+
"use_sea": True,
652+
"use_cloud_fetch": False,
653+
"enable_query_result_lz4_compression": False,
654+
},
655+
],
656+
)
657+
def test_inline_named_dont_break_sql(self, extra_params):
544658
"""With inline mode, ordinal parameters can break the SQL syntax
545659
because `%` symbols are used to wildcard match within LIKE statements. This test
546660
just proves that's the case.
@@ -550,39 +664,80 @@ def test_inline_named_dont_break_sql(self):
550664
SELECT col_1 FROM base WHERE col_1 LIKE CONCAT(%(one)s, 'onite')
551665
"""
552666
params = {"one": "%(one)s"}
553-
with self.cursor(extra_params={"use_inline_params": True}) as cursor:
667+
with self.cursor(
668+
extra_params={"use_inline_params": True, **extra_params}
669+
) as cursor:
554670
result = cursor.execute(query, parameters=params).fetchone()
555671
print("hello")
556672

557-
def test_native_ordinals_dont_break_sql(self):
673+
@pytest.mark.parametrize(
674+
"extra_params",
675+
[
676+
{},
677+
{
678+
"use_sea": True,
679+
"use_cloud_fetch": False,
680+
"enable_query_result_lz4_compression": False,
681+
},
682+
],
683+
)
684+
def test_native_ordinals_dont_break_sql(self, extra_params):
558685
"""This test accompanies test_inline_ordinals_can_break_sql to prove that ordinal
559686
parameters work in native mode for the exact same query, if we use the right marker `?`
560687
"""
561688
query = "SELECT 'samsonite', ? WHERE 'samsonite' LIKE '%sonite'"
562689
params = ["luggage"]
563-
with self.cursor(extra_params={"use_inline_params": False}) as cursor:
690+
with self.cursor(
691+
extra_params={"use_inline_params": False, **extra_params}
692+
) as cursor:
564693
result = cursor.execute(query, parameters=params).fetchone()
565694

566695
assert result.samsonite == "samsonite"
567696
assert result.luggage == "luggage"
568697

569-
def test_inline_like_wildcard_breaks(self):
698+
@pytest.mark.parametrize(
699+
"extra_params",
700+
[
701+
{},
702+
{
703+
"use_sea": True,
704+
"use_cloud_fetch": False,
705+
"enable_query_result_lz4_compression": False,
706+
},
707+
],
708+
)
709+
def test_inline_like_wildcard_breaks(self, extra_params):
570710
"""One flaw with the ParameterEscaper is that it fails if a query contains
571711
a SQL LIKE wildcard %. This test proves that's the case.
572712
"""
573713
query = "SELECT 1 `col` WHERE 'foo' LIKE '%'"
574714
params = {"param": "bar"}
575-
with self.cursor(extra_params={"use_inline_params": True}) as cursor:
715+
with self.cursor(
716+
extra_params={"use_inline_params": True, **extra_params}
717+
) as cursor:
576718
with pytest.raises(ValueError, match="unsupported format character"):
577719
result = cursor.execute(query, parameters=params).fetchone()
578720

579-
def test_native_like_wildcard_works(self):
721+
@pytest.mark.parametrize(
722+
"extra_params",
723+
[
724+
{},
725+
{
726+
"use_sea": True,
727+
"use_cloud_fetch": False,
728+
"enable_query_result_lz4_compression": False,
729+
},
730+
],
731+
)
732+
def test_native_like_wildcard_works(self, extra_params):
580733
"""This is a mirror of test_inline_like_wildcard_breaks that proves that LIKE
581734
wildcards work under the native approach.
582735
"""
583736
query = "SELECT 1 `col` WHERE 'foo' LIKE '%'"
584737
params = {"param": "bar"}
585-
with self.cursor(extra_params={"use_inline_params": False}) as cursor:
738+
with self.cursor(
739+
extra_params={"use_inline_params": False, **extra_params}
740+
) as cursor:
586741
result = cursor.execute(query, parameters=params).fetchone()
587742

588743
assert result.col == 1

0 commit comments

Comments
 (0)