Skip to content

Commit b2b8a2a

Browse files
committed
More tests
1 parent 37c89b8 commit b2b8a2a

File tree

1 file changed

+97
-4
lines changed

1 file changed

+97
-4
lines changed

tests/e2e/test_parameterized_queries.py

Lines changed: 97 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
from typing import Dict, List, Type, Union
66
from unittest.mock import patch
77

8+
import time
9+
import numpy as np
810
import pytest
911
import pytz
12+
from numpy.random.mtrand import Sequence
1013

1114
from databricks.sql.parameters.native import (
1215
BigIntegerParameter,
@@ -112,6 +115,7 @@ class TestParameterizedQueries(PySQLPytestTestCase):
112115
Primitive.NONE: "null_col",
113116
}
114117

118+
115119
def _get_inline_table_column(self, value):
116120
return self.inline_type_map[Primitive(value)]
117121

@@ -142,7 +146,9 @@ def inline_table(self, connection_details):
142146
date_col DATE,
143147
timestamp_col TIMESTAMP,
144148
array_col ARRAY<STRING>,
145-
map_col MAP<STRING, INT>
149+
map_col MAP<STRING, INT>,
150+
array_map_col ARRAY<MAP<STRING,INT>>,
151+
map_array_col MAP<INT,ARRAY<STRING>>
146152
) USING DELTA
147153
"""
148154

@@ -163,7 +169,7 @@ def patch_server_supports_native_params(self, supports_native_params: bool = Tru
163169
finally:
164170
pass
165171

166-
def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle):
172+
def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column):
167173
"""This INSERT, SELECT, DELETE dance is necessary because simply selecting
168174
```
169175
"SELECT %(param)s"
@@ -174,7 +180,6 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle):
174180
:paramstyle:
175181
This is a no-op but is included to make the test-code easier to read.
176182
"""
177-
target_column = self._get_inline_table_column(params.get("p"))
178183
INSERT_QUERY = f"INSERT INTO ___________________first.jprakash.pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)"
179184
SELECT_QUERY = f"SELECT {target_column} `col` FROM ___________________first.jprakash.pysql_e2e_inline_param_test_table LIMIT 1"
180185
DELETE_QUERY = "DELETE FROM ___________________first.jprakash.pysql_e2e_inline_param_test_table"
@@ -220,7 +225,7 @@ def _get_one_result(
220225
if approach == ParameterApproach.INLINE:
221226
# inline mode always uses ParamStyle.PYFORMAT
222227
# inline mode doesn't support positional parameters
223-
return self._inline_roundtrip(params, paramstyle=ParamStyle.PYFORMAT)
228+
return self._inline_roundtrip(params, paramstyle=ParamStyle.PYFORMAT,target_column=self._get_inline_table_column(params.get("p")))
224229
elif approach == ParameterApproach.NATIVE:
225230
# native mode can use either ParamStyle.NAMED or ParamStyle.PYFORMAT
226231
# native mode can use either ParameterStructure.NAMED or ParameterStructure.POSITIONAL
@@ -250,6 +255,57 @@ def _eq(self, actual, expected: Primitive):
250255

251256
return actual_parsed == expected_parsed
252257

258+
def _parse_to_common_type(self, value):
259+
"""
260+
Function to convert the :value passed into a common python datatype for comparison
261+
262+
Convertion fyi
263+
MAP Datatype on server is returned as a list of tuples
264+
Ex:
265+
{"a":1,"b":2} -> [("a",1),("b",2)]
266+
267+
ARRAY Datatype on server is returned as a numpy array
268+
Ex:
269+
["a","b","c"] -> np.array(["a","b","c"],dtype=object)
270+
271+
Primitive datatype on server is returned as a numpy primitive
272+
Ex:
273+
1 -> np.int64(1)
274+
2 -> np.int32(2)
275+
"""
276+
if value is None:
277+
return None
278+
elif isinstance(value, (Sequence,np.ndarray)) and not isinstance(value, (str, bytes)):
279+
return tuple(value)
280+
elif isinstance(value,dict):
281+
return tuple(value.items())
282+
elif isinstance(value, np.generic):
283+
return value.item()
284+
else:
285+
return value
286+
287+
def _recursive_compare(self,actual, expected):
288+
"""
289+
Function to compare the :actual and :expected values, recursively checks and ensures that all the data matches till the leaf level
290+
291+
Note: Complex datatype like MAP is not returned as a dictionary but as a list of tuples
292+
"""
293+
actual_parsed = self._parse_to_common_type(actual)
294+
expected_parsed = self._parse_to_common_type(expected)
295+
296+
# Check if types are the same
297+
if type(actual_parsed) != type(expected_parsed):
298+
return False
299+
300+
# Handle lists or tuples
301+
if isinstance(actual_parsed, (list, tuple)):
302+
if len(actual_parsed) != len(expected_parsed):
303+
return False
304+
return all(self._recursive_compare(o1, o2) for o1, o2 in zip(actual_parsed, expected_parsed))
305+
306+
return actual_parsed==expected_parsed
307+
308+
253309
@pytest.mark.parametrize("primitive", Primitive)
254310
@pytest.mark.parametrize(
255311
"approach,paramstyle,parameter_structure", approach_paramstyle_combinations
@@ -379,6 +435,43 @@ def test_readme_example(self):
379435
assert len(result) == 10
380436
assert result[0].p == "foo"
381437

438+
@pytest.mark.parametrize(
439+
"col_name,data",[
440+
("array_map_col", [{"a": 1, "b": 2}, {"c": 3, "d": 4}]),
441+
("map_array_col", {1: ["a", "b"], 2: ["c", "d"]}),
442+
])
443+
def test_inline_recursive_complex_type(self,col_name,data):
444+
params = {'p':data}
445+
result = self._inline_roundtrip(params=params,paramstyle=ParamStyle.PYFORMAT,target_column=col_name)
446+
assert self._recursive_compare(result.col, data)
447+
448+
449+
@pytest.mark.parametrize(
450+
"description,data",
451+
[
452+
("ARRAY<MAP<STRING,INT>>",[{"a": 1, "b": 2}, {"c": 3, "d": 4}]),
453+
("MAP<INT,ARRAY<STRING>>",{1: ["a", "b"], 2: ["c", "d"]}),
454+
("ARRAY<ARRAY<INT>>",[[1,2,3],[1,2,3]]),
455+
("ARRAY<ARRAY<ARRAY<INT>>>",[[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]]]),
456+
("MAP<STRING,MAP<STRING,STRING>>",{"a":{"b":"c","d":"e"},"f":{"g":"h","i":"j"}})
457+
],
458+
)
459+
@pytest.mark.parametrize(
460+
"paramstyle,parameter_structure",
461+
[
462+
(ParamStyle.NONE, ParameterStructure.POSITIONAL),
463+
(ParamStyle.PYFORMAT, ParameterStructure.NAMED),
464+
(ParamStyle.NAMED, ParameterStructure.NAMED),
465+
]
466+
)
467+
def test_native_recursive_complex_type(self,description,data,paramstyle,parameter_structure):
468+
if(paramstyle==ParamStyle.NONE):
469+
params=[data]
470+
else:
471+
params={'p':data}
472+
result = self._native_roundtrip(parameters=params,paramstyle=paramstyle,parameter_structure=parameter_structure)
473+
assert self._recursive_compare(result.col, data)
474+
382475

383476
class TestInlineParameterSyntax(PySQLPytestTestCase):
384477
"""The inline parameter approach uses pyformat markers"""

0 commit comments

Comments
 (0)