Skip to content

Commit 647ed39

Browse files
committed
fix unit tests
1 parent 0b1b05b commit 647ed39

File tree

2 files changed

+192
-155
lines changed

2 files changed

+192
-155
lines changed

src/databricks/sql/client.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,17 +1367,32 @@ def _convert_arrow_table(self, table):
13671367
pyarrow.float64(): pandas.Float64Dtype(),
13681368
pyarrow.string(): pandas.StringDtype(),
13691369
}
1370+
1371+
arrow_pandas_type_override = self.connection._arrow_pandas_type_override
1372+
if not isinstance(arrow_pandas_type_override, dict):
1373+
logger.debug(
1374+
"_arrow_pandas_type_override on connection was not a dict, using default type mapping"
1375+
)
1376+
arrow_pandas_type_override = {}
1377+
13701378
dtype_mapping = {
13711379
**DEFAULT_DTYPE_MAPPING,
1372-
**self.connection._arrow_pandas_type_override,
1380+
**arrow_pandas_type_override,
13731381
}
13741382

13751383
to_pandas_kwargs: dict[str, Any] = {
13761384
"types_mapper": dtype_mapping.get,
13771385
"date_as_object": True,
13781386
"timestamp_as_object": True,
13791387
}
1380-
to_pandas_kwargs.update(self.connection._arrow_to_pandas_kwargs)
1388+
1389+
arrow_to_pandas_kwargs = self.connection._arrow_to_pandas_kwargs
1390+
if isinstance(arrow_to_pandas_kwargs, dict):
1391+
to_pandas_kwargs.update(arrow_to_pandas_kwargs)
1392+
else:
1393+
logger.debug(
1394+
"_arrow_to_pandas_kwargs on connection was not a dict, using default arguments"
1395+
)
13811396

13821397
# Need to rename columns, as the to_pandas function cannot handle duplicate column names
13831398
table_renamed = table.rename_columns([str(c) for c in range(table.num_columns)])
Lines changed: 175 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -1,162 +1,184 @@
11
import pytest
2-
import pyarrow
2+
3+
try:
4+
import pyarrow as pa
5+
except ImportError:
6+
pa = None
37
import pandas
48
import datetime
5-
from unittest.mock import MagicMock, patch
9+
import unittest
10+
from unittest.mock import MagicMock
611

712
from databricks.sql.client import ResultSet, Connection, ExecuteResponse
813
from databricks.sql.types import Row
914
from databricks.sql.utils import ArrowQueue
1015

11-
12-
@pytest.fixture
13-
def mock_connection():
14-
conn = MagicMock(spec=Connection)
15-
conn.disable_pandas = False
16-
conn._arrow_pandas_type_override = {}
17-
conn._arrow_to_pandas_kwargs = {}
18-
if not hasattr(conn, "_arrow_to_pandas_kwargs"):
16+
@pytest.mark.skipif(pa is None, reason="PyArrow is not installed")
17+
class ArrowConversionTests(unittest.TestCase):
18+
@staticmethod
19+
def mock_connection_static():
20+
conn = MagicMock(spec=Connection)
21+
conn.disable_pandas = False
22+
conn._arrow_pandas_type_override = {}
1923
conn._arrow_to_pandas_kwargs = {}
20-
return conn
21-
22-
23-
@pytest.fixture
24-
def mock_thrift_backend(sample_arrow_table):
25-
tb = MagicMock()
26-
empty_arrays = [
27-
pyarrow.array([], type=field.type) for field in sample_arrow_table.schema
28-
]
29-
empty_table = pyarrow.Table.from_arrays(
30-
empty_arrays, schema=sample_arrow_table.schema
31-
)
32-
tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False)
33-
return tb
34-
35-
36-
@pytest.fixture
37-
def mock_raw_execute_response():
38-
er = MagicMock(spec=ExecuteResponse)
39-
er.description = [
40-
("col_int", "int", None, None, None, None, None),
41-
("col_str", "string", None, None, None, None, None),
42-
]
43-
er.arrow_schema_bytes = None
44-
er.arrow_queue = None
45-
er.has_more_rows = False
46-
er.lz4_compressed = False
47-
er.command_handle = MagicMock()
48-
er.status = MagicMock()
49-
er.has_been_closed_server_side = False
50-
er.is_staging_operation = False
51-
return er
52-
53-
54-
@pytest.fixture
55-
def sample_arrow_table():
56-
data = [
57-
pyarrow.array([1, 2, 3], type=pyarrow.int32()),
58-
pyarrow.array(["a", "b", "c"], type=pyarrow.string()),
59-
]
60-
schema = pyarrow.schema(
61-
[("col_int", pyarrow.int32()), ("col_str", pyarrow.string())]
62-
)
63-
return pyarrow.Table.from_arrays(data, schema=schema)
64-
65-
66-
def test_convert_arrow_table_default(
67-
mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table
68-
):
69-
mock_raw_execute_response.arrow_queue = ArrowQueue(
70-
sample_arrow_table, sample_arrow_table.num_rows
71-
)
72-
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
73-
result_one = rs.fetchone()
74-
assert isinstance(result_one, Row)
75-
assert result_one.col_int == 1
76-
assert result_one.col_str == "a"
77-
mock_raw_execute_response.arrow_queue = ArrowQueue(
78-
sample_arrow_table, sample_arrow_table.num_rows
79-
)
80-
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
81-
result_all = rs.fetchall()
82-
assert len(result_all) == 3
83-
assert isinstance(result_all[0], Row)
84-
assert result_all[0].col_int == 1
85-
assert result_all[1].col_str == "b"
86-
87-
88-
def test_convert_arrow_table_disable_pandas(
89-
mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table
90-
):
91-
mock_connection.disable_pandas = True
92-
mock_raw_execute_response.arrow_queue = ArrowQueue(
93-
sample_arrow_table, sample_arrow_table.num_rows
94-
)
95-
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
96-
result = rs.fetchall()
97-
assert len(result) == 3
98-
assert isinstance(result[0], Row)
99-
assert result[0].col_int == 1
100-
assert result[0].col_str == "a"
101-
assert isinstance(sample_arrow_table.column(0)[0].as_py(), int)
102-
assert isinstance(sample_arrow_table.column(1)[0].as_py(), str)
103-
104-
105-
def test_convert_arrow_table_type_override(
106-
mock_connection, mock_thrift_backend, mock_raw_execute_response, sample_arrow_table
107-
):
108-
mock_connection._arrow_pandas_type_override = {
109-
pyarrow.int32(): pandas.Float64Dtype()
110-
}
111-
mock_raw_execute_response.arrow_queue = ArrowQueue(
112-
sample_arrow_table, sample_arrow_table.num_rows
113-
)
114-
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
115-
result = rs.fetchall()
116-
assert len(result) == 3
117-
assert isinstance(result[0].col_int, float)
118-
assert result[0].col_int == 1.0
119-
assert result[0].col_str == "a"
120-
121-
122-
def test_convert_arrow_table_to_pandas_kwargs(
123-
mock_connection, mock_thrift_backend, mock_raw_execute_response
124-
):
125-
dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)
126-
ts_array = pyarrow.array([dt_obj], type=pyarrow.timestamp("us", tz="UTC"))
127-
ts_schema = pyarrow.schema([("col_ts", pyarrow.timestamp("us", tz="UTC"))])
128-
ts_table = pyarrow.Table.from_arrays([ts_array], schema=ts_schema)
129-
130-
mock_raw_execute_response.description = [
131-
("col_ts", "timestamp", None, None, None, None, None)
132-
]
133-
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows)
134-
135-
# Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row.
136-
mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True}
137-
rs_ts_true = ResultSet(
138-
mock_connection, mock_raw_execute_response, mock_thrift_backend
139-
)
140-
result_true = rs_ts_true.fetchall()
141-
assert len(result_true) == 1
142-
assert isinstance(result_true[0].col_ts, datetime.datetime)
143-
144-
# Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input.
145-
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows)
146-
mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False}
147-
rs_ts_false = ResultSet(
148-
mock_connection, mock_raw_execute_response, mock_thrift_backend
149-
)
150-
result_false = rs_ts_false.fetchall()
151-
assert len(result_false) == 1
152-
assert isinstance(result_false[0].col_ts, pandas.Timestamp)
153-
154-
# Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default.
155-
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows)
156-
mock_connection._arrow_to_pandas_kwargs = {}
157-
rs_ts_true = ResultSet(
158-
mock_connection, mock_raw_execute_response, mock_thrift_backend
159-
)
160-
result_true = rs_ts_true.fetchall()
161-
assert len(result_true) == 1
162-
assert isinstance(result_true[0].col_ts, datetime.datetime)
24+
return conn
25+
26+
@staticmethod
27+
def sample_arrow_table_static():
28+
data = [
29+
pa.array([1, 2, 3], type=pa.int32()),
30+
pa.array(["a", "b", "c"], type=pa.string()),
31+
]
32+
schema = pa.schema([("col_int", pa.int32()), ("col_str", pa.string())])
33+
return pa.Table.from_arrays(data, schema=schema)
34+
35+
@staticmethod
36+
def mock_thrift_backend_static():
37+
sample_table = ArrowConversionTests.sample_arrow_table_static()
38+
tb = MagicMock()
39+
empty_arrays = [pa.array([], type=field.type) for field in sample_table.schema]
40+
empty_table = pa.Table.from_arrays(empty_arrays, schema=sample_table.schema)
41+
tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False)
42+
return tb
43+
44+
@staticmethod
45+
def mock_raw_execute_response_static():
46+
er = MagicMock(spec=ExecuteResponse)
47+
er.description = [
48+
("col_int", "int", None, None, None, None, None),
49+
("col_str", "string", None, None, None, None, None),
50+
]
51+
er.arrow_schema_bytes = None
52+
er.arrow_queue = None
53+
er.has_more_rows = False
54+
er.lz4_compressed = False
55+
er.command_handle = MagicMock()
56+
er.status = MagicMock()
57+
er.has_been_closed_server_side = False
58+
er.is_staging_operation = False
59+
return er
60+
61+
def test_convert_arrow_table_default(self):
62+
mock_connection = ArrowConversionTests.mock_connection_static()
63+
sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
64+
mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
65+
mock_raw_execute_response = (
66+
ArrowConversionTests.mock_raw_execute_response_static()
67+
)
68+
69+
mock_raw_execute_response.arrow_queue = ArrowQueue(
70+
sample_arrow_table, sample_arrow_table.num_rows
71+
)
72+
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
73+
result_one = rs.fetchone()
74+
self.assertIsInstance(result_one, Row)
75+
self.assertEqual(result_one.col_int, 1)
76+
self.assertEqual(result_one.col_str, "a")
77+
78+
mock_raw_execute_response.arrow_queue = ArrowQueue(
79+
sample_arrow_table, sample_arrow_table.num_rows
80+
)
81+
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
82+
result_all = rs.fetchall()
83+
self.assertEqual(len(result_all), 3)
84+
self.assertIsInstance(result_all[0], Row)
85+
self.assertEqual(result_all[0].col_int, 1)
86+
self.assertEqual(result_all[1].col_str, "b")
87+
88+
def test_convert_arrow_table_disable_pandas(self):
89+
mock_connection = ArrowConversionTests.mock_connection_static()
90+
sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
91+
mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
92+
mock_raw_execute_response = (
93+
ArrowConversionTests.mock_raw_execute_response_static()
94+
)
95+
96+
mock_connection.disable_pandas = True
97+
mock_raw_execute_response.arrow_queue = ArrowQueue(
98+
sample_arrow_table, sample_arrow_table.num_rows
99+
)
100+
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
101+
result = rs.fetchall()
102+
self.assertEqual(len(result), 3)
103+
self.assertIsInstance(result[0], Row)
104+
self.assertEqual(result[0].col_int, 1)
105+
self.assertEqual(result[0].col_str, "a")
106+
self.assertIsInstance(sample_arrow_table.column(0)[0].as_py(), int)
107+
self.assertIsInstance(sample_arrow_table.column(1)[0].as_py(), str)
108+
109+
def test_convert_arrow_table_type_override(self):
110+
mock_connection = ArrowConversionTests.mock_connection_static()
111+
sample_arrow_table = ArrowConversionTests.sample_arrow_table_static()
112+
mock_thrift_backend = ArrowConversionTests.mock_thrift_backend_static()
113+
mock_raw_execute_response = (
114+
ArrowConversionTests.mock_raw_execute_response_static()
115+
)
116+
117+
mock_connection._arrow_pandas_type_override = {
118+
pa.int32(): pandas.Float64Dtype()
119+
}
120+
mock_raw_execute_response.arrow_queue = ArrowQueue(
121+
sample_arrow_table, sample_arrow_table.num_rows
122+
)
123+
rs = ResultSet(mock_connection, mock_raw_execute_response, mock_thrift_backend)
124+
result = rs.fetchall()
125+
self.assertEqual(len(result), 3)
126+
self.assertIsInstance(result[0].col_int, float)
127+
self.assertEqual(result[0].col_int, 1.0)
128+
self.assertEqual(result[0].col_str, "a")
129+
130+
def test_convert_arrow_table_to_pandas_kwargs(self):
131+
mock_connection = ArrowConversionTests.mock_connection_static()
132+
mock_thrift_backend = (
133+
ArrowConversionTests.mock_thrift_backend_static()
134+
) # Does not use sample_arrow_table
135+
mock_raw_execute_response = (
136+
ArrowConversionTests.mock_raw_execute_response_static()
137+
)
138+
139+
dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)
140+
ts_array = pa.array([dt_obj], type=pa.timestamp("us", tz="UTC"))
141+
ts_schema = pa.schema([("col_ts", pa.timestamp("us", tz="UTC"))])
142+
ts_table = pa.Table.from_arrays([ts_array], schema=ts_schema)
143+
144+
mock_raw_execute_response.description = [
145+
("col_ts", "timestamp", None, None, None, None, None)
146+
]
147+
mock_raw_execute_response.arrow_queue = ArrowQueue(ts_table, ts_table.num_rows)
148+
149+
# Scenario 1: timestamp_as_object = True. Observed as datetime.datetime in Row.
150+
mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": True}
151+
rs_ts_true = ResultSet(
152+
mock_connection, mock_raw_execute_response, mock_thrift_backend
153+
)
154+
result_true = rs_ts_true.fetchall()
155+
self.assertEqual(len(result_true), 1)
156+
self.assertIsInstance(result_true[0].col_ts, datetime.datetime)
157+
158+
# Scenario 2: timestamp_as_object = False. Observed as pandas.Timestamp in Row for this input.
159+
mock_raw_execute_response.arrow_queue = ArrowQueue(
160+
ts_table, ts_table.num_rows
161+
) # Reset queue
162+
mock_connection._arrow_to_pandas_kwargs = {"timestamp_as_object": False}
163+
rs_ts_false = ResultSet(
164+
mock_connection, mock_raw_execute_response, mock_thrift_backend
165+
)
166+
result_false = rs_ts_false.fetchall()
167+
self.assertEqual(len(result_false), 1)
168+
self.assertIsInstance(result_false[0].col_ts, pandas.Timestamp)
169+
170+
# Scenario 3: no override. Observed as datetime.datetime in Row since timestamp_as_object is True by default.
171+
mock_raw_execute_response.arrow_queue = ArrowQueue(
172+
ts_table, ts_table.num_rows
173+
) # Reset queue
174+
mock_connection._arrow_to_pandas_kwargs = {}
175+
rs_ts_default = ResultSet(
176+
mock_connection, mock_raw_execute_response, mock_thrift_backend
177+
)
178+
result_default = rs_ts_default.fetchall()
179+
self.assertEqual(len(result_default), 1)
180+
self.assertIsInstance(result_default[0].col_ts, datetime.datetime)
181+
182+
183+
if __name__ == "__main__":
184+
unittest.main()

0 commit comments

Comments
 (0)