|
1 | 1 | import pytest |
2 | | -import pyarrow |
| 2 | + |
| 3 | +try: |
| 4 | + import pyarrow as pa |
| 5 | +except ImportError: |
| 6 | + pa = None |
3 | 7 | import pandas |
4 | 8 | import datetime |
5 | | -from unittest.mock import MagicMock, patch |
| 9 | +import unittest |
| 10 | +from unittest.mock import MagicMock |
6 | 11 |
|
7 | 12 | from databricks.sql.client import ResultSet, Connection, ExecuteResponse |
8 | 13 | from databricks.sql.types import Row |
9 | 14 | from databricks.sql.utils import ArrowQueue |
10 | 15 |
|
11 | | - |
12 | | -@pytest.fixture |
13 | | -def mock_connection(): |
14 | | - conn = MagicMock(spec=Connection) |
15 | | - conn.disable_pandas = False |
16 | | - conn._arrow_pandas_type_override = {} |
17 | | - conn._arrow_to_pandas_kwargs = {} |
18 | | - if not hasattr(conn, "_arrow_to_pandas_kwargs"): |
@unittest.skipIf(pa is None, "PyArrow is not installed")
class ArrowConversionTests(unittest.TestCase):
    """Check how ``ResultSet`` turns Arrow tables into ``Row`` objects.

    Each test hands a small in-memory ``pyarrow`` table to a ``ResultSet``
    through an ``ArrowQueue`` and inspects what ``fetchone()`` /
    ``fetchall()`` return. The connection, Thrift backend and execute
    response are ``MagicMock`` stand-ins.

    The class is skipped with ``unittest.skipIf`` (not a pytest mark) so the
    skip is honoured by pytest *and* by the ``unittest.main()`` entry point of
    this module; a pytest-only mark would let the tests run, and crash on
    ``pa`` being ``None``, under the plain unittest runner.
    """

    # ------------------------------------------------------------------
    # Factories for the collaborators a ResultSet needs.
    # ------------------------------------------------------------------

    @staticmethod
    def mock_connection_static():
        """Return a ``Connection`` mock with pandas enabled and no overrides."""
        conn = MagicMock(spec=Connection)
        conn.disable_pandas = False
        conn._arrow_pandas_type_override = {}
        conn._arrow_to_pandas_kwargs = {}
        return conn

    @staticmethod
    def sample_arrow_table_static():
        """Return a 3-row table: ``col_int`` (int32) 1..3, ``col_str`` a..c."""
        data = [
            pa.array([1, 2, 3], type=pa.int32()),
            pa.array(["a", "b", "c"], type=pa.string()),
        ]
        schema = pa.schema([("col_int", pa.int32()), ("col_str", pa.string())])
        return pa.Table.from_arrays(data, schema=schema)

    @staticmethod
    def mock_thrift_backend_static():
        """Return a backend mock whose ``fetch_results`` yields an empty queue.

        The mocked return value is ``(ArrowQueue(empty_table, 0), False)``,
        where ``empty_table`` has the sample table's schema and zero rows.
        NOTE(review): the trailing ``False`` is presumably the "has more rows"
        flag -- confirm against the real backend's ``fetch_results``.
        """
        sample_table = ArrowConversionTests.sample_arrow_table_static()
        tb = MagicMock()
        empty_arrays = [pa.array([], type=field.type) for field in sample_table.schema]
        empty_table = pa.Table.from_arrays(empty_arrays, schema=sample_table.schema)
        tb.fetch_results.return_value = (ArrowQueue(empty_table, 0), False)
        return tb

    @staticmethod
    def mock_raw_execute_response_static():
        """Return an ``ExecuteResponse`` mock describing ``col_int``/``col_str``.

        ``arrow_queue`` starts as ``None``; each test installs its own queue.
        """
        er = MagicMock(spec=ExecuteResponse)
        er.description = [
            ("col_int", "int", None, None, None, None, None),
            ("col_str", "string", None, None, None, None, None),
        ]
        er.arrow_schema_bytes = None
        er.arrow_queue = None
        er.has_more_rows = False
        er.lz4_compressed = False
        er.command_handle = MagicMock()
        er.status = MagicMock()
        er.has_been_closed_server_side = False
        er.is_staging_operation = False
        return er

    # ------------------------------------------------------------------
    # Per-test setup and shared helper.
    # ------------------------------------------------------------------

    def setUp(self):
        """Build fresh mocks and a fresh sample table for every test."""
        self.connection = self.mock_connection_static()
        self.arrow_table = self.sample_arrow_table_static()
        self.thrift_backend = self.mock_thrift_backend_static()
        self.execute_response = self.mock_raw_execute_response_static()

    def _result_set_for(self, table):
        """Queue *table* on the execute response and wrap it in a new ResultSet.

        A new ``ArrowQueue`` is created on every call so that a table already
        fetched by a previous ``ResultSet`` can be served again.
        """
        self.execute_response.arrow_queue = ArrowQueue(table, table.num_rows)
        return ResultSet(self.connection, self.execute_response, self.thrift_backend)

    # ------------------------------------------------------------------
    # Tests.
    # ------------------------------------------------------------------

    def test_convert_arrow_table_default(self):
        """With default settings, fetchone/fetchall return ``Row`` objects."""
        first_row = self._result_set_for(self.arrow_table).fetchone()
        self.assertIsInstance(first_row, Row)
        self.assertEqual(first_row.col_int, 1)
        self.assertEqual(first_row.col_str, "a")

        all_rows = self._result_set_for(self.arrow_table).fetchall()
        self.assertEqual(len(all_rows), 3)
        self.assertIsInstance(all_rows[0], Row)
        self.assertEqual(all_rows[0].col_int, 1)
        self.assertEqual(all_rows[1].col_str, "b")

    def test_convert_arrow_table_disable_pandas(self):
        """With ``disable_pandas = True`` the same rows are still returned."""
        self.connection.disable_pandas = True

        rows = self._result_set_for(self.arrow_table).fetchall()
        self.assertEqual(len(rows), 3)
        self.assertIsInstance(rows[0], Row)
        self.assertEqual(rows[0].col_int, 1)
        self.assertEqual(rows[0].col_str, "a")
        # Sanity check on the *input* table: its cells convert to plain
        # Python int / str via ``as_py()``.
        self.assertIsInstance(self.arrow_table.column(0)[0].as_py(), int)
        self.assertIsInstance(self.arrow_table.column(1)[0].as_py(), str)

    def test_convert_arrow_table_type_override(self):
        """An int32 -> ``Float64Dtype`` override makes ``col_int`` a float."""
        self.connection._arrow_pandas_type_override = {
            pa.int32(): pandas.Float64Dtype()
        }

        rows = self._result_set_for(self.arrow_table).fetchall()
        self.assertEqual(len(rows), 3)
        self.assertIsInstance(rows[0].col_int, float)
        self.assertEqual(rows[0].col_int, 1.0)
        self.assertEqual(rows[0].col_str, "a")

    def test_convert_arrow_table_to_pandas_kwargs(self):
        """``_arrow_to_pandas_kwargs`` controls the Python type of timestamps.

        This test builds its own single-row UTC timestamp table instead of
        using the int/str sample table.
        """
        ts_type = pa.timestamp("us", tz="UTC")
        dt_obj = datetime.datetime(2021, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)
        ts_table = pa.Table.from_arrays(
            [pa.array([dt_obj], type=ts_type)],
            schema=pa.schema([("col_ts", ts_type)]),
        )
        self.execute_response.description = [
            ("col_ts", "timestamp", None, None, None, None, None)
        ]

        # (kwargs passed through the connection, type observed in the Row)
        # NOTE(review): pandas.Timestamp subclasses datetime.datetime, so the
        # datetime.datetime expectations below would also accept a Timestamp;
        # consider tightening them once the exact behaviour is confirmed.
        scenarios = [
            # timestamp_as_object = True: observed as datetime.datetime.
            ({"timestamp_as_object": True}, datetime.datetime),
            # timestamp_as_object = False: observed as pandas.Timestamp.
            ({"timestamp_as_object": False}, pandas.Timestamp),
            # No override: observed as datetime.datetime, since
            # timestamp_as_object is True by default.
            ({}, datetime.datetime),
        ]
        for to_pandas_kwargs, expected_type in scenarios:
            # subTest reports every failing scenario instead of stopping at
            # the first one.
            with self.subTest(to_pandas_kwargs=to_pandas_kwargs):
                self.connection._arrow_to_pandas_kwargs = to_pandas_kwargs
                rows = self._result_set_for(ts_table).fetchall()
                self.assertEqual(len(rows), 1)
                self.assertIsInstance(rows[0].col_ts, expected_type)
| 181 | + |
| 182 | + |
if __name__ == "__main__":
    # Running the file directly hands control to unittest's command-line runner.
    unittest.main()
0 commit comments