Skip to content

Commit 2e52044

Browse files
committed
Bumping version to 0.0.17
1 parent 2c8f4a7 commit 2e52044

File tree

5 files changed

+68
-30
lines changed

5 files changed

+68
-30
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
> Utility belt to handle data on AWS.
44
5-
[![Release](https://img.shields.io/badge/release-0.0.16-brightgreen.svg)](https://pypi.org/project/awswrangler/)
5+
[![Release](https://img.shields.io/badge/release-0.0.17-brightgreen.svg)](https://pypi.org/project/awswrangler/)
66
[![Downloads](https://img.shields.io/pypi/dm/awswrangler.svg)](https://pypi.org/project/awswrangler/)
77
[![Python Version](https://img.shields.io/badge/python-3.6%20%7C%203.7-brightgreen.svg)](https://pypi.org/project/awswrangler/)
88
[![Documentation Status](https://readthedocs.org/projects/aws-data-wrangler/badge/?version=latest)](https://aws-data-wrangler.readthedocs.io/en/latest/?badge=latest)

awswrangler/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
11
__title__ = "awswrangler"
22
__description__ = "Utility belt to handle data on AWS."
3-
__version__ = "0.0.16"
3+
__version__ = "0.0.17"
44
__license__ = "Apache License 2.0"

awswrangler/athena.py

Lines changed: 9 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,10 @@
11
from typing import Dict, List, Tuple, Optional, Any, Iterator
22
from time import sleep
33
import logging
4-
import ast
54
import re
65
import unicodedata
76

8-
from awswrangler.data_types import athena2python, athena2pandas
7+
from awswrangler.data_types import athena2python
98
from awswrangler.exceptions import QueryFailed, QueryCancelled
109

1110
logger = logging.getLogger(__name__)
@@ -18,33 +17,16 @@ def __init__(self, session):
1817
self._session = session
1918
self._client_athena = session.boto3_session.client(service_name="athena", config=session.botocore_config)
2019

21-
def get_query_columns_metadata(self, query_execution_id):
22-
response = self._client_athena.get_query_results(QueryExecutionId=query_execution_id, MaxResults=1)
23-
col_info = response["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]
20+
def get_query_columns_metadata(self, query_execution_id: str) -> Dict[str, str]:
21+
"""
22+
Get the data type of all columns queried
23+
:param query_execution_id: Athena query execution ID
24+
:return: Dictionary with all data types
25+
"""
26+
response: Dict = self._client_athena.get_query_results(QueryExecutionId=query_execution_id, MaxResults=1)
27+
col_info: List[Dict[str, str]] = response["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]
2428
return {x["Name"]: x["Type"] for x in col_info}
2529

26-
def get_query_dtype(self, query_execution_id):
27-
cols_metadata = self.get_query_columns_metadata(query_execution_id=query_execution_id)
28-
logger.debug(f"cols_metadata: {cols_metadata}")
29-
dtype = {}
30-
parse_timestamps = []
31-
parse_dates = []
32-
converters = {}
33-
for col_name, col_type in cols_metadata.items():
34-
pandas_type = athena2pandas(dtype=col_type)
35-
if pandas_type in ["datetime64", "date"]:
36-
parse_timestamps.append(col_name)
37-
if pandas_type == "date":
38-
parse_dates.append(col_name)
39-
elif pandas_type == "literal_eval":
40-
converters[col_name] = ast.literal_eval
41-
else:
42-
dtype[col_name] = pandas_type
43-
logger.debug(f"dtype: {dtype}")
44-
logger.debug(f"parse_timestamps: {parse_timestamps}")
45-
logger.debug(f"parse_dates: {parse_dates}")
46-
return dtype, parse_timestamps, parse_dates, converters
47-
4830
def create_athena_bucket(self):
4931
"""
5032
Creates the default Athena bucket if not exists

awswrangler/pandas.py

Lines changed: 29 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import copy
77
import csv
88
from datetime import datetime
9+
import ast
910

1011
import pandas as pd # type: ignore
1112
import pyarrow as pa # type: ignore
@@ -416,6 +417,33 @@ def _read_csv_once(
416417
buff.close()
417418
return dataframe
418419

420+
def _get_query_dtype(self, query_execution_id: str) -> Tuple[Dict[str, str], List[str], List[str], Dict[str, Any]]:
421+
cols_metadata: Dict[str, str] = self._session.athena.get_query_columns_metadata(
422+
query_execution_id=query_execution_id)
423+
logger.debug(f"cols_metadata: {cols_metadata}")
424+
dtype: Dict[str, str] = {}
425+
parse_timestamps: List[str] = []
426+
parse_dates: List[str] = []
427+
converters: Dict[str, Any] = {}
428+
col_name: str
429+
col_type: str
430+
for col_name, col_type in cols_metadata.items():
431+
pandas_type: str = data_types.athena2pandas(dtype=col_type)
432+
if pandas_type in ["datetime64", "date"]:
433+
parse_timestamps.append(col_name)
434+
if pandas_type == "date":
435+
parse_dates.append(col_name)
436+
elif pandas_type == "literal_eval":
437+
converters[col_name] = ast.literal_eval
438+
elif pandas_type == "bool":
439+
logger.debug(f"Ignoring bool column: {col_name}")
440+
else:
441+
dtype[col_name] = pandas_type
442+
logger.debug(f"dtype: {dtype}")
443+
logger.debug(f"parse_timestamps: {parse_timestamps}")
444+
logger.debug(f"parse_dates: {parse_dates}")
445+
return dtype, parse_timestamps, parse_dates, converters
446+
419447
def read_sql_athena(self, sql, database, s3_output=None, max_result_size=None):
420448
"""
421449
Executes any SQL query on AWS Athena and return a Dataframe of the result.
@@ -436,7 +464,7 @@ def read_sql_athena(self, sql, database, s3_output=None, max_result_size=None):
436464
message_error = f"Query error: {reason}"
437465
raise AthenaQueryError(message_error)
438466
else:
439-
dtype, parse_timestamps, parse_dates, converters = self._session.athena.get_query_dtype(
467+
dtype, parse_timestamps, parse_dates, converters = self._get_query_dtype(
440468
query_execution_id=query_execution_id)
441469
path = f"{s3_output}{query_execution_id}.csv"
442470
ret = self.read_csv(path=path,

testing/test_awswrangler/test_pandas.py

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -921,3 +921,31 @@ def test_to_parquet_casting_with_null_object(
921921
database=database,
922922
path=f"s3://{bucket}/test/",
923923
mode="overwrite")
924+
925+
926+
def test_read_sql_athena_with_nulls(session, bucket, database):
927+
df = pd.DataFrame({"col_int": [1, None, 3], "col_bool": [True, False, False], "col_bool_null": [True, None, False]})
928+
path = f"s3://{bucket}/test/"
929+
session.pandas.to_parquet(dataframe=df,
930+
database=database,
931+
path=path,
932+
preserve_index=False,
933+
mode="overwrite",
934+
cast_columns={
935+
"col_int": "int",
936+
"col_bool_null": "boolean"
937+
})
938+
df2 = None
939+
for counter in range(10):
940+
df2 = session.pandas.read_sql_athena(sql="select * from test", database=database)
941+
assert len(list(df.columns)) == len(list(df2.columns))
942+
if len(df.index) == len(df2.index):
943+
break
944+
sleep(1)
945+
assert len(df.index) == len(df2.index)
946+
print(df2)
947+
print(df2.dtypes)
948+
assert df2.dtypes[0] == "Int64"
949+
assert df2.dtypes[1] == "bool"
950+
assert df2.dtypes[2] == "object"
951+
session.s3.delete_objects(path=path)

0 commit comments

Comments (0)