Skip to content

Commit 26bcddf

Browse files
authored
Merge pull request #77 from awslabs/pandas-read-athena-array
Fix Pandas.read_sql_athena() for arrays
2 parents f100a78 + 0c91d4b commit 26bcddf

File tree

4 files changed

+57
-6
lines changed

4 files changed

+57
-6
lines changed

awswrangler/athena.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def get_query_columns_metadata(self, query_execution_id: str) -> Dict[str, str]:
2525
"""
2626
response: Dict = self._client_athena.get_query_results(QueryExecutionId=query_execution_id, MaxResults=1)
2727
col_info: List[Dict[str, str]] = response["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]
28+
logger.debug(f"col_info: {col_info}")
2829
return {x["Name"]: x["Type"] for x in col_info}
2930

3031
def create_athena_bucket(self):

awswrangler/data_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def athena2pandas(dtype: str) -> str:
2525
elif dtype == "date":
2626
return "date"
2727
elif dtype == "array":
28-
return "literal_eval"
28+
return "list"
2929
else:
3030
raise UnsupportedType(f"Unsupported Athena type: {dtype}")
3131

awswrangler/pandas.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,39 @@ def _read_csv_once(
418418
buff.close()
419419
return dataframe
420420

421+
@staticmethod
def _list_parser(value: str) -> List[Union[int, float, str, None]]:
    """Parse the text form of an array cell into a Python list.

    The input is expected to look like ``"[a, b, c]"``: one enclosing
    character on each side and elements separated by ``", "``. The literal
    element ``null`` becomes ``None``. All remaining elements are converted
    to a single common type, chosen as the first of ``int``, ``float`` and
    ``str`` that accepts every non-null element.

    Limitation: elements are split on ``", "`` verbatim, so a string element
    that itself contains ``", "`` is split into several elements, and a
    string element spelled ``null`` cannot be told apart from a real null.

    :param value: Raw cell text, e.g. ``"[1, null, 3]"``.
    :return: Parsed elements; an empty list for ``""``, a single character,
        or an empty array such as ``"[]"``.
    """
    if len(value) <= 1:
        return []

    inner: str = value[1:-1]
    if not inner:
        # "[]" must give an empty list; "".split(", ") would give [""],
        # which would then be returned as a one-element list of "".
        return []

    items: List[Optional[str]] = [None if x == "null" else x for x in inner.split(", ")]

    # Try the narrowest type first: a cast that succeeds for every non-null
    # element decides the element type of the whole list.
    for caster in (int, float):
        try:
            return [None if x is None else caster(x) for x in items]
        except ValueError:
            continue

    # Neither numeric type fits every element: keep them as strings.
    return list(items)
453+
421454
def _get_query_dtype(self, query_execution_id: str) -> Tuple[Dict[str, str], List[str], List[str], Dict[str, Any]]:
422455
cols_metadata: Dict[str, str] = self._session.athena.get_query_columns_metadata(
423456
query_execution_id=query_execution_id)
@@ -434,15 +467,16 @@ def _get_query_dtype(self, query_execution_id: str) -> Tuple[Dict[str, str], Lis
434467
parse_timestamps.append(col_name)
435468
if pandas_type == "date":
436469
parse_dates.append(col_name)
437-
elif pandas_type == "literal_eval":
438-
converters[col_name] = ast.literal_eval
470+
elif pandas_type == "list":
471+
converters[col_name] = Pandas._list_parser
439472
elif pandas_type == "bool":
440473
logger.debug(f"Ignoring bool column: {col_name}")
441474
else:
442475
dtype[col_name] = pandas_type
443476
logger.debug(f"dtype: {dtype}")
444477
logger.debug(f"parse_timestamps: {parse_timestamps}")
445478
logger.debug(f"parse_dates: {parse_dates}")
479+
logger.debug(f"converters: {converters}")
446480
return dtype, parse_timestamps, parse_dates, converters
447481

448482
def read_sql_athena(self, sql, database=None, s3_output=None, max_result_size=None, workgroup=None,

testing/test_awswrangler/test_pandas.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,9 +1285,10 @@ def test_to_parquet_date_null_at_first(session, bucket, database):
12851285
def test_to_parquet_array(session, bucket, database):
12861286
df = pd.DataFrame({
12871287
"A": [1, 2, 3],
1288-
"B": [[], [4, 5, 6], []],
1289-
"C": [[], ["foo", "boo", "bar"], []],
1290-
"D": [7, 8, 9]
1288+
"B": [[], [4.0, None, 6.0], []],
1289+
"C": [[], [7, None, 9], []],
1290+
"D": [[], ["foo", None, "bar"], []],
1291+
"E": [10, 11, 12]
12911292
})
12921293
path = f"s3://{bucket}/test/"
12931294
session.pandas.to_parquet(dataframe=df,
@@ -1296,3 +1297,18 @@ def test_to_parquet_array(session, bucket, database):
12961297
mode="overwrite",
12971298
preserve_index=False,
12981299
procs_cpu_bound=1)
1300+
df2 = None
1301+
for counter in range(10): # Retrying to workaround s3 eventual consistency
1302+
sleep(1)
1303+
df2 = session.pandas.read_sql_athena(sql="select * from test", database=database)
1304+
if len(df.index) == len(df2.index):
1305+
break
1306+
print(df2)
1307+
session.s3.delete_objects(path=path)
1308+
1309+
assert len(list(df.columns)) == len(list(df2.columns))
1310+
assert len(df.index) == len(df2.index)
1311+
1312+
assert df2[df2.a == 2].iloc[0].b[0] == 4.0
1313+
assert df2[df2.a == 2].iloc[0].c[0] == 7
1314+
assert df2[df2.a == 2].iloc[0].d[0] == "foo"

0 commit comments

Comments
 (0)