Skip to content

Commit 09f4d18

Browse files
committed
refractor
1 parent 4c36d99 commit 09f4d18

File tree

6 files changed

+168
-46
lines changed

6 files changed

+168
-46
lines changed

src/databricks/sql/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,7 @@ def execute(
789789
790790
:returns self
791791
"""
792-
792+
793793
param_approach = self._determine_parameter_approach(parameters)
794794
if param_approach == ParameterApproach.NONE:
795795
prepared_params = NO_NATIVE_PARAMS
@@ -808,7 +808,7 @@ def execute(
808808
prepared_operation, prepared_params = self._prepare_native_parameters(
809809
transformed_operation, normalized_parameters, param_structure
810810
)
811-
811+
812812
self._check_not_closed()
813813
self._close_and_clear_active_result_set()
814814
execute_response = self.thrift_backend.execute_command(

tests/e2e/test_complex_types.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,14 @@ def table_fixture(self, connection_details):
4545

4646
@pytest.mark.parametrize(
4747
"field,expected_type",
48-
[("array_col", ndarray), ("map_col", list), ("struct_col", dict), ("array_array_col", ndarray), ("array_map_col", ndarray), ("map_array_col", list)],
48+
[
49+
("array_col", ndarray),
50+
("map_col", list),
51+
("struct_col", dict),
52+
("array_array_col", ndarray),
53+
("array_map_col", ndarray),
54+
("map_array_col", list),
55+
],
4956
)
5057
def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
5158
"""Confirms the return types of a complex type field when reading as arrow"""
@@ -54,10 +61,20 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
5461
result = cursor.execute(
5562
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
5663
).fetchone()
57-
64+
5865
assert isinstance(result[field], expected_type)
5966

60-
@pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col"), ("array_array_col"), ("array_map_col"), ("map_array_col")])
67+
@pytest.mark.parametrize(
68+
"field",
69+
[
70+
("array_col"),
71+
("map_col"),
72+
("struct_col"),
73+
("array_array_col"),
74+
("array_map_col"),
75+
("map_array_col"),
76+
],
77+
)
6178
def test_read_complex_types_as_string(self, field, table_fixture):
6279
"""Confirms the return type of a complex type that is returned as a string"""
6380
with self.cursor(

tests/e2e/test_parameterized_queries.py

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ class TestParameterizedQueries(PySQLPytestTestCase):
115115
Primitive.NONE: "null_col",
116116
}
117117

118-
119118
def _get_inline_table_column(self, value):
120119
return self.inline_type_map[Primitive(value)]
121120

@@ -225,7 +224,11 @@ def _get_one_result(
225224
if approach == ParameterApproach.INLINE:
226225
# inline mode always uses ParamStyle.PYFORMAT
227226
# inline mode doesn't support positional parameters
228-
return self._inline_roundtrip(params, paramstyle=ParamStyle.PYFORMAT,target_column=self._get_inline_table_column(params.get("p")))
227+
return self._inline_roundtrip(
228+
params,
229+
paramstyle=ParamStyle.PYFORMAT,
230+
target_column=self._get_inline_table_column(params.get("p")),
231+
)
229232
elif approach == ParameterApproach.NATIVE:
230233
# native mode can use either ParamStyle.NAMED or ParamStyle.PYFORMAT
231234
# native mode can use either ParameterStructure.NAMED or ParameterStructure.POSITIONAL
@@ -252,7 +255,7 @@ def _eq(self, actual, expected: Primitive):
252255
actual_parsed = actual.tolist()
253256
elif expected == Primitive.MAPS:
254257
expected_parsed = list(expected.value.items())
255-
258+
256259
return actual_parsed == expected_parsed
257260

258261
def _parse_to_common_type(self, value):
@@ -263,28 +266,30 @@ def _parse_to_common_type(self, value):
263266
MAP Datatype on server is returned as a list of tuples
264267
Ex:
265268
{"a":1,"b":2} -> [("a",1),("b",2)]
266-
269+
267270
ARRAY Datatype on server is returned as a numpy array
268271
Ex:
269272
["a","b","c"] -> np.array(["a","b","c"],dtype=object)
270-
273+
271274
Primitive datatype on server is returned as a numpy primitive
272275
Ex:
273276
1 -> np.int64(1)
274277
2 -> np.int32(2)
275278
"""
276279
if value is None:
277280
return None
278-
elif isinstance(value, (Sequence,np.ndarray)) and not isinstance(value, (str, bytes)):
281+
elif isinstance(value, (Sequence, np.ndarray)) and not isinstance(
282+
value, (str, bytes)
283+
):
279284
return tuple(value)
280-
elif isinstance(value,dict):
285+
elif isinstance(value, dict):
281286
return tuple(value.items())
282287
elif isinstance(value, np.generic):
283288
return value.item()
284289
else:
285290
return value
286291

287-
def _recursive_compare(self,actual, expected):
292+
def _recursive_compare(self, actual, expected):
288293
"""
289294
Function to compare the :actual and :expected values, recursively checks and ensures that all the data matches till the leaf level
290295
@@ -301,10 +306,12 @@ def _recursive_compare(self,actual, expected):
301306
if isinstance(actual_parsed, (list, tuple)):
302307
if len(actual_parsed) != len(expected_parsed):
303308
return False
304-
return all(self._recursive_compare(o1, o2) for o1, o2 in zip(actual_parsed, expected_parsed))
305-
306-
return actual_parsed==expected_parsed
309+
return all(
310+
self._recursive_compare(o1, o2)
311+
for o1, o2 in zip(actual_parsed, expected_parsed)
312+
)
307313

314+
return actual_parsed == expected_parsed
308315

309316
@pytest.mark.parametrize("primitive", Primitive)
310317
@pytest.mark.parametrize(
@@ -436,24 +443,33 @@ def test_readme_example(self):
436443
assert result[0].p == "foo"
437444

438445
@pytest.mark.parametrize(
439-
"col_name,data",[
440-
("array_map_col", [{"a": 1, "b": 2}, {"c": 3, "d": 4}]),
441-
("map_array_col", {1: ["a", "b"], 2: ["c", "d"]}),
442-
])
443-
def test_inline_recursive_complex_type(self,col_name,data):
444-
params = {'p':data}
445-
result = self._inline_roundtrip(params=params,paramstyle=ParamStyle.PYFORMAT,target_column=col_name)
446+
"col_name,data",
447+
[
448+
("array_map_col", [{"a": 1, "b": 2}, {"c": 3, "d": 4}]),
449+
("map_array_col", {1: ["a", "b"], 2: ["c", "d"]}),
450+
],
451+
)
452+
def test_inline_recursive_complex_type(self, col_name, data):
453+
params = {"p": data}
454+
result = self._inline_roundtrip(
455+
params=params, paramstyle=ParamStyle.PYFORMAT, target_column=col_name
456+
)
446457
assert self._recursive_compare(result.col, data)
447-
448-
458+
449459
@pytest.mark.parametrize(
450460
"description,data",
451461
[
452-
("ARRAY<MAP<STRING,INT>>",[{"a": 1, "b": 2}, {"c": 3, "d": 4}]),
453-
("MAP<INT,ARRAY<STRING>>",{1: ["a", "b"], 2: ["c", "d"]}),
454-
("ARRAY<ARRAY<INT>>",[[1,2,3],[1,2,3]]),
455-
("ARRAY<ARRAY<ARRAY<INT>>>",[[[1,2,3],[1,2,3]],[[1,2,3],[1,2,3]]]),
456-
("MAP<STRING,MAP<STRING,STRING>>",{"a":{"b":"c","d":"e"},"f":{"g":"h","i":"j"}})
462+
("ARRAY<MAP<STRING,INT>>", [{"a": 1, "b": 2}, {"c": 3, "d": 4}]),
463+
("MAP<INT,ARRAY<STRING>>", {1: ["a", "b"], 2: ["c", "d"]}),
464+
("ARRAY<ARRAY<INT>>", [[1, 2, 3], [1, 2, 3]]),
465+
(
466+
"ARRAY<ARRAY<ARRAY<INT>>>",
467+
[[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]],
468+
),
469+
(
470+
"MAP<STRING,MAP<STRING,STRING>>",
471+
{"a": {"b": "c", "d": "e"}, "f": {"g": "h", "i": "j"}},
472+
),
457473
],
458474
)
459475
@pytest.mark.parametrize(
@@ -462,14 +478,20 @@ def test_inline_recursive_complex_type(self,col_name,data):
462478
(ParamStyle.NONE, ParameterStructure.POSITIONAL),
463479
(ParamStyle.PYFORMAT, ParameterStructure.NAMED),
464480
(ParamStyle.NAMED, ParameterStructure.NAMED),
465-
]
481+
],
466482
)
467-
def test_native_recursive_complex_type(self,description,data,paramstyle,parameter_structure):
468-
if(paramstyle==ParamStyle.NONE):
469-
params=[data]
483+
def test_native_recursive_complex_type(
484+
self, description, data, paramstyle, parameter_structure
485+
):
486+
if paramstyle == ParamStyle.NONE:
487+
params = [data]
470488
else:
471-
params={'p':data}
472-
result = self._native_roundtrip(parameters=params,paramstyle=paramstyle,parameter_structure=parameter_structure)
489+
params = {"p": data}
490+
result = self._native_roundtrip(
491+
parameters=params,
492+
paramstyle=paramstyle,
493+
parameter_structure=parameter_structure,
494+
)
473495
assert self._recursive_compare(result.col, data)
474496

475497

tests/unit/test_param_escaper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_escape_sequence_sequence_of_strings(self):
136136
OUTPUT = "ARRAY(ARRAY('his','name'),ARRAY('was','robert'),ARRAY('palmer'))"
137137

138138
assert pe.escape_sequence(INPUT) == OUTPUT
139-
139+
140140
def test_escape_map_string_int(self):
141141
INPUT = {"a": 1, "b": 2}
142142
OUTPUT = "MAP('a',1,'b',2)"
@@ -146,7 +146,7 @@ def test_escape_map_string_sequence_of_floats(self):
146146
INPUT = {"a": [1.1, 2.2, 3.3], "b": [4.4, 5.5, 6.6]}
147147
OUTPUT = "MAP('a',ARRAY(1.1,2.2,3.3),'b',ARRAY(4.4,5.5,6.6))"
148148
assert pe.escape_mapping(INPUT) == OUTPUT
149-
149+
150150
def test_escape_sequence_of_map_int_string(self):
151151
INPUT = [{1: "a", 2: "foo"}, {3: "b", 4: "bar"}]
152152
OUTPUT = "ARRAY(MAP(1,'a',2,'foo'),MAP(3,'b',4,'bar'))"

tests/unit/test_parameters.py

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,96 @@ def test_tspark_param_value(self, t: TDbsqlParameter, prim):
173173
assert output == None
174174
else:
175175
assert output == TSparkParameterValue(stringValue=str(prim.value))
176-
176+
177177
@pytest.mark.parametrize(
178-
"base_type,input,expected_output",[
179-
(ArrayParameter, [1,2,3], TSparkParameter(ordinal=True, name=None, type='ARRAY', value=None, arguments=[TSparkParameterValueArg(type='INT', value='1', arguments=None), TSparkParameterValueArg(type='INT', value='2', arguments=None), TSparkParameterValueArg(type='INT', value='3', arguments=None)])),
180-
(MapParameter, {"a": 1, "b": 2}, TSparkParameter(ordinal=True, name=None, type='MAP', value=None, arguments=[TSparkParameterValueArg(type='STRING', value='a', arguments=None), TSparkParameterValueArg(type='INT', value='1', arguments=None), TSparkParameterValueArg(type='STRING', value='b', arguments=None), TSparkParameterValueArg(type='INT', value='2', arguments=None)])),
181-
(ArrayParameter,[{"a":1,"b":2},{"c":3,"d":4}], TSparkParameter(ordinal=True, name=None, type='ARRAY', value=None, arguments=[TSparkParameterValueArg(type='MAP', value=None, arguments=[TSparkParameterValueArg(type='STRING', value='a', arguments=None), TSparkParameterValueArg(type='INT', value='1', arguments=None), TSparkParameterValueArg(type='STRING', value='b', arguments=None), TSparkParameterValueArg(type='INT', value='2', arguments=None)]), TSparkParameterValueArg(type='MAP', value=None, arguments=[TSparkParameterValueArg(type='STRING', value='c', arguments=None), TSparkParameterValueArg(type='INT', value='3', arguments=None), TSparkParameterValueArg(type='STRING', value='d', arguments=None), TSparkParameterValueArg(type='INT', value='4', arguments=None)])])),
182-
]
178+
"base_type,input,expected_output",
179+
[
180+
(
181+
ArrayParameter,
182+
[1, 2, 3],
183+
TSparkParameter(
184+
ordinal=True,
185+
name=None,
186+
type="ARRAY",
187+
value=None,
188+
arguments=[
189+
TSparkParameterValueArg(type="INT", value="1", arguments=None),
190+
TSparkParameterValueArg(type="INT", value="2", arguments=None),
191+
TSparkParameterValueArg(type="INT", value="3", arguments=None),
192+
],
193+
),
194+
),
195+
(
196+
MapParameter,
197+
{"a": 1, "b": 2},
198+
TSparkParameter(
199+
ordinal=True,
200+
name=None,
201+
type="MAP",
202+
value=None,
203+
arguments=[
204+
TSparkParameterValueArg(
205+
type="STRING", value="a", arguments=None
206+
),
207+
TSparkParameterValueArg(type="INT", value="1", arguments=None),
208+
TSparkParameterValueArg(
209+
type="STRING", value="b", arguments=None
210+
),
211+
TSparkParameterValueArg(type="INT", value="2", arguments=None),
212+
],
213+
),
214+
),
215+
(
216+
ArrayParameter,
217+
[{"a": 1, "b": 2}, {"c": 3, "d": 4}],
218+
TSparkParameter(
219+
ordinal=True,
220+
name=None,
221+
type="ARRAY",
222+
value=None,
223+
arguments=[
224+
TSparkParameterValueArg(
225+
type="MAP",
226+
value=None,
227+
arguments=[
228+
TSparkParameterValueArg(
229+
type="STRING", value="a", arguments=None
230+
),
231+
TSparkParameterValueArg(
232+
type="INT", value="1", arguments=None
233+
),
234+
TSparkParameterValueArg(
235+
type="STRING", value="b", arguments=None
236+
),
237+
TSparkParameterValueArg(
238+
type="INT", value="2", arguments=None
239+
),
240+
],
241+
),
242+
TSparkParameterValueArg(
243+
type="MAP",
244+
value=None,
245+
arguments=[
246+
TSparkParameterValueArg(
247+
type="STRING", value="c", arguments=None
248+
),
249+
TSparkParameterValueArg(
250+
type="INT", value="3", arguments=None
251+
),
252+
TSparkParameterValueArg(
253+
type="STRING", value="d", arguments=None
254+
),
255+
TSparkParameterValueArg(
256+
type="INT", value="4", arguments=None
257+
),
258+
],
259+
),
260+
],
261+
),
262+
),
263+
],
183264
)
184-
def test_complex_type_tspark_param(self,base_type,input,expected_output):
265+
def test_complex_type_tspark_param(self, base_type, input, expected_output):
185266
p = base_type(input)
186267
tsp = p.as_tspark_param()
187268
assert tsp == expected_output

tests/unit/test_thrift_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ def test_make_request_checks_thrift_status_code(self):
8686

8787
def _make_type_desc(self, type):
8888
return ttypes.TTypeDesc(
89-
types=[ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type))]
89+
types=[
90+
ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type))
91+
]
9092
)
9193

9294
def _make_fake_thrift_backend(self):

0 commit comments

Comments
 (0)