Skip to content

Commit affbdef

Browse files
authored
feat(cursor): Add method insert_data_bulk (#81)
1 parent 8d19929 commit affbdef

File tree

2 files changed

+207
-1
lines changed

2 files changed

+207
-1
lines changed

redshift_connector/cursor.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,101 @@ def executemany(self: "Cursor", operation, param_sets) -> "Cursor":
239239
self._row_count = -1 if -1 in rowcounts else sum(rowcounts)
240240
return self
241241

242+
def insert_data_bulk(
    self: "Cursor", filename, table_name, column_indexes, column_names, delimeter
) -> "Cursor":
    """Insert rows read from a delimited text file using a single ``INSERT`` statement.

    The first row of the file is treated as a header and skipped. For every
    remaining row, the fields at ``column_indexes`` are collected (in the given
    order) and bound as parameters of one multi-row
    ``INSERT INTO <table_name> (<column_names>) VALUES (...), (...);``.

    This method is native to redshift_connector.

    :param filename: str
        The name of the file to read from.
    :param table_name: str
        The name of the table to insert to.
    :param column_indexes: list
        The zero-based positions, within each file row, of the fields to insert.
        Must be the same length as ``column_names``.
    :param column_names: list
        The names of the columns in the table to insert to.
    :param delimeter: str
        The delimiter to use when reading the file. (The parameter name keeps
        its historical spelling so existing keyword callers continue to work.)

    Returns
    -------
    The Cursor object used for executing the specified database operation: :class:`Cursor`

    Raises
    ------
    InterfaceError
        If the table or column names fail validation, if ``column_names`` and
        ``column_indexes`` differ in length, if the file contains no data rows,
        or if reading the file or executing the statement fails.
    """
    import csv

    # table_name and column_names are interpolated directly into the SQL text
    # below, so they are validated before anything else happens.
    if not self.__is_valid_table(table_name):
        raise InterfaceError(
            "Invalid table name passed to insert_data_bulk: {}".format(table_name)
        )
    if not self.__has_valid_columns(table_name, column_names):
        raise InterfaceError(
            "Invalid column names passed to insert_data_bulk: {}".format(
                column_names
            )
        )
    if len(column_names) != len(column_indexes):
        raise InterfaceError("Column names and indexes must be the same length")

    # One "(%s, %s, ...)" group per inserted row.
    sql_param_list_template = "(" + ", ".join(["%s"] * len(column_indexes)) + ")"

    orig_paramstyle = self.paramstyle
    try:
        # newline="" lets the csv module do its own newline handling, as the
        # csv documentation requires for file objects.
        with open(filename, newline="") as csv_file:
            reader = csv.reader(csv_file, delimiter=delimeter)
            next(reader, None)  # skip the header row; tolerate an empty file
            values_list = []
            row_count = 0
            for row in reader:
                values_list.extend(row[column_index] for column_index in column_indexes)
                row_count += 1

        if row_count == 0:
            # Without this guard the statement would end in "VALUES ;".
            raise InterfaceError(
                "No data rows found in file passed to insert_data_bulk: {}".format(
                    filename
                )
            )

        sql_query = "INSERT INTO {} ({}) VALUES {};".format(
            table_name,
            ", ".join(column_names),
            ", ".join([sql_param_list_template] * row_count),
        )

        # The statement is written with %s placeholders, i.e. the DB-API
        # "format" paramstyle, so pin that style for this call regardless of
        # how the cursor is configured. The finally clause restores it.
        self.paramstyle = "format"
        self.execute(sql_query, values_list)
    except InterfaceError:
        raise
    except Exception as e:
        raise InterfaceError(e) from e
    finally:
        # reset paramstyle to its original value
        self.paramstyle = orig_paramstyle

    return self
304+
305+
def __has_valid_columns(
    self: "Cursor", table: str, columns: typing.List[str]
) -> bool:
    """Verify that every name in ``columns`` is a column of ``table``.

    One lookup against ``information_schema.columns`` is executed per column.
    If ``table`` has the form ``schema.table`` the lookup is also restricted to
    that schema; otherwise the first dot-separated component is used as the
    table name.

    :param table: str
        The table name, optionally qualified as ``schema.table``.
    :param columns: list
        The column names to look up.

    Returns
    -------
    ``True`` when every column was found: bool

    Raises
    ------
    InterfaceError
        If a column is not found for the table.
    """
    split_table_name: typing.List[str] = table.split(".")
    q: str = "select 1 from information_schema.columns where table_name = ? and column_name = ?"
    param_list: typing.List[typing.List[str]]
    if len(split_table_name) == 2:
        q += " and table_schema = ?"
        param_list = [
            [split_table_name[1], c, split_table_name[0]] for c in columns
        ]
    else:
        param_list = [[split_table_name[0], c] for c in columns]

    # The query uses "?" placeholders, so pin the paramstyle to "qmark" for the
    # duration of the lookups and restore the caller's setting afterwards.
    temp = self.paramstyle
    self.paramstyle = "qmark"
    try:
        for params in param_list:
            self.execute(q, params)
            res = self.fetchone()
            # fetchone() yields None when the lookup matched no row, which is
            # the "column does not exist" case; subscripting it would raise
            # TypeError instead of the intended InterfaceError.
            if not res or res[0] != 1:
                raise InterfaceError(
                    "Invalid column name: {} specified for table: {}".format(
                        params[1], table
                    )
                )
    finally:
        # reset paramstyle to its original value
        self.paramstyle = temp

    return True
336+
242337
def callproc(self, procname, parameters=None):
243338
args = [] if parameters is None else parameters
244339
operation = "CALL " + self.__sanitize_str(procname) + "(" + ", ".join(["%s" for _ in args]) + ")"

test/unit/test_cursor.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import typing
22
from test.utils import pandas_only
3-
from unittest.mock import Mock, PropertyMock, patch
3+
from unittest.mock import Mock, PropertyMock, patch ,mock_open
44

55
import pytest # type: ignore
66

@@ -249,3 +249,114 @@ def test_get_tables_considers_args(is_single_database_metadata_val, _input, sche
249249
for arg in (schema_pattern, table_name_pattern):
250250
if arg is not None:
251251
assert arg in spy.call_args[0][1]
252+
253+
254+
@pytest.mark.parametrize("indexes, names", [([1], []), ([], ["c1"])])
def test_insert_data_column_names_indexes_mismatch_raises(indexes, names, mocker):
    """insert_data_bulk rejects column name/index lists of different lengths."""
    # fetchone always reports a matching row so that the table_name and
    # column_name validation steps pass and the length check is reached
    mocker.patch("redshift_connector.Cursor.fetchone", return_value=[1])

    # build the cursor without running __init__, then attach a stand-in connection
    cursor_under_test: Cursor = Cursor.__new__(Cursor)
    cursor_under_test._c = Mock()
    cursor_under_test.paramstyle = "qmark"

    expected_message = "Column names and indexes must be the same length"
    call_kwargs = dict(
        filename="test_file",
        table_name="test_table",
        column_indexes=indexes,
        column_names=names,
        delimeter=",",
    )

    with pytest.raises(InterfaceError, match=expected_message):
        cursor_under_test.insert_data_bulk(**call_kwargs)
275+
276+
277+
# In-memory stand-in for the CSV file: a header line followed by three data
# rows, with no trailing newline after the last row.
in_mem_csv = "\n".join(
    [
        "col1,col2,col3",
        "1,3,foo",
        "2,5,bar",
        "-1,7,baz",
    ]
)

# Each case is (column_indexes, column_names, (expected SQL, expected params)).
# The expected params are the selected fields of every data row, flattened in
# row order.
insert_bulk_data = [
    # --- one column ---
    ([0], ["col1"], ("INSERT INTO test_table (col1) VALUES (%s), (%s), (%s);", ["1", "2", "-1"])),
    ([1], ["col2"], ("INSERT INTO test_table (col2) VALUES (%s), (%s), (%s);", ["3", "5", "7"])),
    ([2], ["col3"], ("INSERT INTO test_table (col3) VALUES (%s), (%s), (%s);", ["foo", "bar", "baz"])),
    # --- two columns ---
    (
        [0, 1],
        ["col1", "col2"],
        (
            "INSERT INTO test_table (col1, col2) VALUES (%s, %s), (%s, %s), (%s, %s);",
            ["1", "3", "2", "5", "-1", "7"],
        ),
    ),
    (
        [0, 2],
        ["col1", "col3"],
        (
            "INSERT INTO test_table (col1, col3) VALUES (%s, %s), (%s, %s), (%s, %s);",
            ["1", "foo", "2", "bar", "-1", "baz"],
        ),
    ),
    (
        [1, 2],
        ["col2", "col3"],
        (
            "INSERT INTO test_table (col2, col3) VALUES (%s, %s), (%s, %s), (%s, %s);",
            ["3", "foo", "5", "bar", "7", "baz"],
        ),
    ),
    # --- all three columns ---
    (
        [0, 1, 2],
        ["col1", "col2", "col3"],
        (
            "INSERT INTO test_table (col1, col2, col3) VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s);",
            ["1", "3", "foo", "2", "5", "bar", "-1", "7", "baz"],
        ),
    ),
]
335+
336+
337+
@patch("builtins.open", new_callable=mock_open, read_data=in_mem_csv)
@pytest.mark.parametrize("indexes,names,exp_execute_args", insert_bulk_data)
def test_insert_data_column_stmt(mocked_csv, indexes, names, exp_execute_args, mocker):
    """insert_data_bulk passes the expected SQL text and parameter list to execute."""
    expected_sql, expected_params = exp_execute_args[0], exp_execute_args[1]

    # fetchone always reports a matching row so that the table_name and
    # column_name validation steps pass
    mocker.patch("redshift_connector.Cursor.fetchone", return_value=[1])

    # build the cursor without running __init__
    cursor_under_test: Cursor = Cursor.__new__(Cursor)

    # wrap execute in a spy so the statement handed to it can be inspected
    execute_spy = mocker.spy(cursor_under_test, "execute")

    # stand-in for the connection
    cursor_under_test._c = Mock()
    cursor_under_test.paramstyle = "qmark"

    cursor_under_test.insert_data_bulk(
        filename="mocked_csv",
        table_name="test_table",
        column_indexes=indexes,
        column_names=names,
        delimeter=",",
    )

    assert execute_spy.called is True
    # call_args describes the most recent execute call; [0] holds its positional args
    positional_args = execute_spy.call_args[0]
    assert positional_args[0] == expected_sql
    assert positional_args[1] == expected_params

0 commit comments

Comments
 (0)