Skip to content

Commit 4c36d99

Browse files
committed
Added unit tests
1 parent b2b8a2a commit 4c36d99

File tree

6 files changed

+47
-9
lines changed

6 files changed

+47
-9
lines changed

src/databricks/sql/parameters/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@
1212
TimestampNTZParameter,
1313
TinyIntParameter,
1414
DecimalParameter,
15+
MapParameter,
16+
ArrayParameter,
1517
)

src/databricks/sql/parameters/native.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def __init__(self, value: Sequence[Any], name: Optional[str] = None):
467467
self.name = name
468468
self.value = [dbsql_parameter_from_primitive(val) for val in value]
469469

470-
def as_tspark_param(self, named: bool) -> TSparkParameter:
470+
def as_tspark_param(self, named: bool = None) -> TSparkParameter:
471471
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""
472472

473473
tsp = TSparkParameter(type=self._cast_expr())
@@ -516,7 +516,7 @@ def __init__(self, value: dict, name: Optional[str] = None):
516516
for item in (key, val)
517517
]
518518

519-
def as_tspark_param(self, named: bool) -> TSparkParameter:
519+
def as_tspark_param(self, named: bool = None) -> TSparkParameter:
520520
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""
521521

522522
tsp = TSparkParameter(type=self._cast_expr())

src/databricks/sql/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def escape_mapping(self, item):
473473
def escape_datetime(self, item, format, cutoff=0):
474474
dt_str = item.strftime(format)
475475
formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str
476-
return "'{}'".format(formatted)
476+
return "'{}'".format(formatted.strip())
477477

478478
def escape_decimal(self, item):
479479
return str(item)

tests/unit/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def test_execute_parameter_passthrough(self):
455455
("SELECT %(x)s", "SELECT NULL", {"x": None}),
456456
("SELECT %(int_value)d", "SELECT 48", {"int_value": 48}),
457457
("SELECT %(float_value).2f", "SELECT 48.20", {"float_value": 48.2}),
458-
("SELECT %(iter)s", "SELECT (1,2,3,4,5)", {"iter": [1, 2, 3, 4, 5]}),
458+
("SELECT %(iter)s", "SELECT ARRAY(1,2,3,4,5)", {"iter": [1, 2, 3, 4, 5]}),
459459
(
460460
"SELECT %(datetime)s",
461461
"SELECT '2022-02-01 10:23:00.000000'",

tests/unit/test_param_escaper.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,23 +120,37 @@ def test_escape_date(self):
120120
assert pe.escape_datetime(INPUT, FORMAT) == OUTPUT
121121

122122
def test_escape_sequence_integer(self):
123-
assert pe.escape_sequence([1, 2, 3, 4]) == "(1,2,3,4)"
123+
assert pe.escape_sequence([1, 2, 3, 4]) == "ARRAY(1,2,3,4)"
124124

125125
def test_escape_sequence_float(self):
126-
assert pe.escape_sequence([1.1, 2.2, 3.3, 4.4]) == "(1.1,2.2,3.3,4.4)"
126+
assert pe.escape_sequence([1.1, 2.2, 3.3, 4.4]) == "ARRAY(1.1,2.2,3.3,4.4)"
127127

128128
def test_escape_sequence_string(self):
129129
assert (
130130
pe.escape_sequence(["his", "name", "was", "robert", "palmer"])
131-
== "('his','name','was','robert','palmer')"
131+
== "ARRAY('his','name','was','robert','palmer')"
132132
)
133133

134134
def test_escape_sequence_sequence_of_strings(self):
135-
# This is not valid SQL.
136135
INPUT = [["his", "name"], ["was", "robert"], ["palmer"]]
137-
OUTPUT = "(('his','name'),('was','robert'),('palmer'))"
136+
OUTPUT = "ARRAY(ARRAY('his','name'),ARRAY('was','robert'),ARRAY('palmer'))"
138137

139138
assert pe.escape_sequence(INPUT) == OUTPUT
139+
140+
def test_escape_map_string_int(self):
141+
INPUT = {"a": 1, "b": 2}
142+
OUTPUT = "MAP('a',1,'b',2)"
143+
assert pe.escape_mapping(INPUT) == OUTPUT
144+
145+
def test_escape_map_string_sequence_of_floats(self):
146+
INPUT = {"a": [1.1, 2.2, 3.3], "b": [4.4, 5.5, 6.6]}
147+
OUTPUT = "MAP('a',ARRAY(1.1,2.2,3.3),'b',ARRAY(4.4,5.5,6.6))"
148+
assert pe.escape_mapping(INPUT) == OUTPUT
149+
150+
def test_escape_sequence_of_map_int_string(self):
151+
INPUT = [{1: "a", 2: "foo"}, {3: "b", 4: "bar"}]
152+
OUTPUT = "ARRAY(MAP(1,'a',2,'foo'),MAP(3,'b',4,'bar'))"
153+
assert pe.escape_sequence(INPUT) == OUTPUT
140154

141155

142156
class TestFullQueryEscaping(object):

tests/unit/test_parameters.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,14 @@
2121
TimestampParameter,
2222
TinyIntParameter,
2323
VoidParameter,
24+
MapParameter,
25+
ArrayParameter,
2426
)
2527
from databricks.sql.parameters.native import (
2628
TDbsqlParameter,
29+
TSparkParameter,
2730
TSparkParameterValue,
31+
TSparkParameterValueArg,
2832
dbsql_parameter_from_primitive,
2933
)
3034
from databricks.sql.thrift_api.TCLIService import ttypes
@@ -112,6 +116,8 @@ class Primitive(Enum):
112116
DOUBLE = 3.14
113117
FLOAT = 3.15
114118
SMALLINT = 51
119+
ARRAY = [1, 2, 3]
120+
MAP = {"a": 1, "b": 2}
115121

116122

117123
class TestDbsqlParameter:
@@ -131,6 +137,8 @@ class TestDbsqlParameter:
131137
(TimestampParameter, Primitive.TIMESTAMP, "TIMESTAMP"),
132138
(TimestampNTZParameter, Primitive.TIMESTAMP, "TIMESTAMP_NTZ"),
133139
(TinyIntParameter, Primitive.INT, "TINYINT"),
140+
(MapParameter, Primitive.MAP, "MAP"),
141+
(ArrayParameter, Primitive.ARRAY, "ARRAY"),
134142
),
135143
)
136144
def test_cast_expression(
@@ -165,6 +173,18 @@ def test_tspark_param_value(self, t: TDbsqlParameter, prim):
165173
assert output == None
166174
else:
167175
assert output == TSparkParameterValue(stringValue=str(prim.value))
176+
177+
@pytest.mark.parametrize(
178+
"base_type,input,expected_output",[
179+
(ArrayParameter, [1,2,3], TSparkParameter(ordinal=True, name=None, type='ARRAY', value=None, arguments=[TSparkParameterValueArg(type='INT', value='1', arguments=None), TSparkParameterValueArg(type='INT', value='2', arguments=None), TSparkParameterValueArg(type='INT', value='3', arguments=None)])),
180+
(MapParameter, {"a": 1, "b": 2}, TSparkParameter(ordinal=True, name=None, type='MAP', value=None, arguments=[TSparkParameterValueArg(type='STRING', value='a', arguments=None), TSparkParameterValueArg(type='INT', value='1', arguments=None), TSparkParameterValueArg(type='STRING', value='b', arguments=None), TSparkParameterValueArg(type='INT', value='2', arguments=None)])),
181+
(ArrayParameter,[{"a":1,"b":2},{"c":3,"d":4}], TSparkParameter(ordinal=True, name=None, type='ARRAY', value=None, arguments=[TSparkParameterValueArg(type='MAP', value=None, arguments=[TSparkParameterValueArg(type='STRING', value='a', arguments=None), TSparkParameterValueArg(type='INT', value='1', arguments=None), TSparkParameterValueArg(type='STRING', value='b', arguments=None), TSparkParameterValueArg(type='INT', value='2', arguments=None)]), TSparkParameterValueArg(type='MAP', value=None, arguments=[TSparkParameterValueArg(type='STRING', value='c', arguments=None), TSparkParameterValueArg(type='INT', value='3', arguments=None), TSparkParameterValueArg(type='STRING', value='d', arguments=None), TSparkParameterValueArg(type='INT', value='4', arguments=None)])])),
182+
]
183+
)
184+
def test_complex_type_tspark_param(self,base_type,input,expected_output):
185+
p = base_type(input)
186+
tsp = p.as_tspark_param()
187+
assert tsp == expected_output
168188

169189
def test_tspark_param_named(self):
170190
p = dbsql_parameter_from_primitive(Primitive.INT.value, name="p")
@@ -192,6 +212,8 @@ def test_tspark_param_ordinal(self):
192212
(FloatParameter, Primitive.FLOAT),
193213
(VoidParameter, Primitive.NONE),
194214
(TimestampParameter, Primitive.TIMESTAMP),
215+
(MapParameter, Primitive.MAP),
216+
(ArrayParameter, Primitive.ARRAY),
195217
),
196218
)
197219
def test_inference(self, _type: TDbsqlParameter, prim: Primitive):

0 commit comments

Comments (0)