Skip to content

Commit 4fc0af9

Browse files
committed
refactor(cursor, description): return column label as str
1 parent 4e44b9d commit 4fc0af9

File tree

5 files changed

+66 -13 lines changed

redshift_connector/cursor.py

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,11 @@
88
from warnings import warn
99

1010
import redshift_connector
11-
from redshift_connector.config import ClientProtocolVersion, table_type_clauses
11+
from redshift_connector.config import (
12+
ClientProtocolVersion,
13+
_client_encoding,
14+
table_type_clauses,
15+
)
1216
from redshift_connector.error import (
1317
MISSING_MODULE_ERROR_MSG,
1418
InterfaceError,
@@ -165,12 +169,17 @@ def truncated_row_desc(self: "Cursor"):
165169
def _getDescription(self: "Cursor") -> typing.Optional[typing.List[typing.Optional[typing.Tuple]]]:
166170
if self.ps is None:
167171
return None
168-
row_desc: typing.List[typing.Dict[str, typing.Union[bytes, int, typing.Callable]]] = self.ps["row_desc"]
172+
row_desc: typing.List[typing.Dict[str, typing.Union[bytes, str, int, typing.Callable]]] = self.ps["row_desc"]
169173
if len(row_desc) == 0:
170174
return None
171175
columns: typing.List[typing.Optional[typing.Tuple]] = []
172176
for col in row_desc:
173-
columns.append((col["label"], col["type_oid"], None, None, None, None, None))
177+
try:
178+
col_name: typing.Union[str, bytes] = typing.cast(bytes, col["label"]).decode(_client_encoding)
179+
except UnicodeError:
180+
warn("failed to decode column name: {}, reverting to bytes".format(col["label"])) # type: ignore
181+
col_name = typing.cast(bytes, col["label"])
182+
columns.append((col_name, col["type_oid"], None, None, None, None, None))
174183
return columns
175184

176185
##
@@ -503,12 +512,6 @@ def fetch_dataframe(self: "Cursor", num: typing.Optional[int] = None) -> typing.
503512

504513
columns: typing.Optional[typing.List[typing.Union[str, bytes]]] = None
505514
try:
506-
columns = [column[0].decode().lower() for column in self.description]
507-
except UnicodeError as e:
508-
warn(
509-
"Unable to decode column names. Byte values will be used for pandas dataframe column labels.",
510-
stacklevel=2,
511-
)
512515
columns = [column[0].lower() for column in self.description]
513516
except:
514517
warn("No row description was found. pandas dataframe will be missing column labels.", stacklevel=2)

test/integration/test_cursor.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,36 @@
1+
import pytest # type: ignore
2+
3+
import redshift_connector
4+
5+
6+
@pytest.mark.parametrize("col_name", (("apples", "apples"), ("author‎ ", "author\u200e")))
7+
def test_get_description(db_kwargs, col_name):
8+
given_col_name, exp_col_name = col_name
9+
with redshift_connector.connect(**db_kwargs) as conn:
10+
with conn.cursor() as cursor:
11+
cursor.execute("create temp table tmptbl({} int)".format(given_col_name))
12+
cursor.execute("select * from tmptbl")
13+
assert cursor.description is not None
14+
assert cursor.description[0][0] == exp_col_name
15+
16+
17+
@pytest.mark.parametrize(
18+
"col_names",
19+
(
20+
("(c1 int, c2 int, c3 int)", ("c1", "c2", "c3")),
21+
(
22+
"(áppleṣ int, orañges int, passion⁘fruit int, papaya  int, bañanaș int)",
23+
("áppleṣ", "orañges", "passion⁘fruit", "papaya\u205f", "bañanaș"),
24+
),
25+
),
26+
)
27+
def test_get_description_multiple_column_names(db_kwargs, col_names):
28+
given_col_names, exp_col_names = col_names
29+
with redshift_connector.connect(**db_kwargs) as conn:
30+
with conn.cursor() as cursor:
31+
cursor.execute("create temp table tmptbl {}".format(given_col_names))
32+
cursor.execute("select * from tmptbl")
33+
assert cursor.description is not None
34+
35+
for cidx, column in enumerate(cursor.description):
36+
assert column[0] == exp_col_names[cidx]

test/integration/test_dbapi20.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ def test_description(con):
110110
cur.execute("select name from %sbooze" % table_prefix)
111111
assert len(cur.description) == 1, "cursor.description describes too many columns"
112112
assert len(cur.description[0]) == 7, "cursor.description[x] tuples must have 7 elements"
113-
assert cur.description[0][0].lower() == b"name", "cursor.description[x][0] must return column name"
113+
assert cur.description[0][0].lower() == "name", "cursor.description[x][0] must return column name"
114114
assert cur.description[0][1] == driver.STRING, (
115115
"cursor.description[x][1] must return column type. Got %r" % cur.description[0][1]
116116
)

test/integration/test_pandas.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@ def test_fetch_dataframe(db_table):
5454
cursor.execute("select * from book; ")
5555
result = cursor.fetch_dataframe()
5656
assert result.columns[0] == "bookname"
57+
assert result.columns[1] == "author\u200e"
5758

5859

5960
@pandas_only

test/unit/test_cursor.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,14 +11,27 @@
1111
IS_SINGLE_DATABASE_METADATA_TOGGLE: typing.List[bool] = [True, False]
1212

1313

14-
test_warn_response_data: typing.List[typing.Tuple[typing.Optional[typing.List[bytes]], str]] = [
15-
([b"ab\xffcd"], "Unable to decode column names. Byte values will be used for pandas dataframe column labels."),
14+
description_warn_response_data: typing.List[typing.Tuple[bytes, str]] = [
15+
(b"ab\xffcd", "failed to decode column name"),
16+
]
17+
18+
19+
@pytest.mark.parametrize("_input", description_warn_response_data)
20+
def test_get_description_warns_user(_input):
21+
data, exp_warning_msg = _input
22+
mock_cursor: Cursor = Cursor.__new__(Cursor)
23+
mock_cursor.__setattr__("ps", {"row_desc": [{"type_oid": 1043, "label": data, "column_name": b"c1"}]})
24+
with pytest.warns(UserWarning, match=exp_warning_msg):
25+
mock_cursor.description
26+
27+
28+
fetch_df_warn_response_data: typing.List[typing.Tuple[typing.Optional[typing.List[bytes]], str]] = [
1629
(None, "No row description was found. pandas dataframe will be missing column labels."),
1730
]
1831

1932

2033
@pandas_only
21-
@pytest.mark.parametrize("_input", test_warn_response_data)
34+
@pytest.mark.parametrize("_input", fetch_df_warn_response_data)
2235
def test_fetch_dataframe_warns_user(_input, mocker):
2336
data, exp_warning_msg = _input
2437
mock_cursor: Cursor = Cursor.__new__(Cursor)

0 commit comments

Comments
 (0)