
Commit 46486b5

Merge pull request #61 from awslabs/fix-cast: Fixing cast issues

2 parents 2e52044 + 1e738bc

The change fixes casting of partition columns: previously the `cast_columns` dict was passed unfiltered to the per-partition object writer, which runs after the partition columns have been dropped from each subgroup. The diff below casts the full DataFrame up front with a new `_cast_pandas` helper and hands the writer only the casts for columns actually materialized in the files.

File tree: 4 files changed, +169 / -4 lines changed


README.md (2 additions, 2 deletions)

```diff
@@ -2,11 +2,11 @@
 
 > Utility belt to handle data on AWS.
 
-[![Release](https://img.shields.io/badge/release-0.0.17-brightgreen.svg)](https://pypi.org/project/awswrangler/)
+[![Release](https://img.shields.io/badge/release-0.0.18-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 [![Downloads](https://img.shields.io/pypi/dm/awswrangler.svg)](https://pypi.org/project/awswrangler/)
 [![Python Version](https://img.shields.io/badge/python-3.6%20%7C%203.7-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 [![Documentation Status](https://readthedocs.org/projects/aws-data-wrangler/badge/?version=latest)](https://aws-data-wrangler.readthedocs.io/en/latest/?badge=latest)
-[![Coverage](https://img.shields.io/badge/coverage-88%25-brightgreen.svg)](https://pypi.org/project/awswrangler/)
+[![Coverage](https://img.shields.io/badge/coverage-89%25-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 [![Average time to resolve an issue](http://isitmaintained.com/badge/resolution/awslabs/aws-data-wrangler.svg)](http://isitmaintained.com/project/awslabs/aws-data-wrangler "Average time to resolve an issue")
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
 
```

awswrangler/__version__.py (1 addition, 1 deletion)

```diff
@@ -1,4 +1,4 @@
 __title__ = "awswrangler"
 __description__ = "Utility belt to handle data on AWS."
-__version__ = "0.0.17"
+__version__ = "0.0.18"
 __license__ = "Apache License 2.0"
```

awswrangler/pandas.py (16 additions, 1 deletion)

```diff
@@ -758,6 +758,9 @@ def _data_to_s3_dataset_writer(dataframe,
                                                            isolated_dataframe=isolated_dataframe)
             objects_paths.append(object_path)
         else:
+            dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
+            cast_columns_materialized = {c: t for c, t in cast_columns.items() if c not in partition_cols}
+            dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
             for keys, subgroup in dataframe.groupby(partition_cols):
                 subgroup = subgroup.drop(partition_cols, axis="columns")
                 if not isinstance(keys, tuple):
@@ -770,12 +773,24 @@ def _data_to_s3_dataset_writer(dataframe,
                                                                compression=compression,
                                                                session_primitives=session_primitives,
                                                                file_format=file_format,
-                                                               cast_columns=cast_columns,
+                                                               cast_columns=cast_columns_materialized,
                                                                extra_args=extra_args,
                                                                isolated_dataframe=True)
                 objects_paths.append(object_path)
         return objects_paths
 
+    @staticmethod
+    def _cast_pandas(dataframe: pd.DataFrame, cast_columns: Dict[str, str]) -> pd.DataFrame:
+        for col, athena_type in cast_columns.items():
+            pandas_type: str = data_types.athena2pandas(dtype=athena_type)
+            if pandas_type == "datetime64":
+                dataframe[col] = pd.to_datetime(dataframe[col])
+            elif pandas_type == "date":
+                dataframe[col] = pd.to_datetime(dataframe[col]).dt.date
+            else:
+                dataframe[col] = dataframe[col].astype(pandas_type)
+        return dataframe
+
     @staticmethod
     def _data_to_s3_dataset_writer_remote(send_pipe,
                                           dataframe,
```
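For context, here is a minimal standalone sketch of what the new `_cast_pandas` helper does. The `ATHENA2PANDAS` dictionary is a simplified, hypothetical stand-in for `data_types.athena2pandas`, which the real code calls and which covers more Athena types:

```python
from typing import Dict

import pandas as pd

# Hypothetical, simplified stand-in for awswrangler.data_types.athena2pandas;
# maps Athena type names to pandas dtype names.
ATHENA2PANDAS: Dict[str, str] = {
    "string": "object",
    "double": "float64",
    "boolean": "bool",
    "timestamp": "datetime64",
    "date": "date",
}


def cast_pandas(dataframe: pd.DataFrame, cast_columns: Dict[str, str]) -> pd.DataFrame:
    """Cast columns from Athena type names to pandas dtypes, mirroring the patch above."""
    for col, athena_type in cast_columns.items():
        pandas_type = ATHENA2PANDAS[athena_type]
        if pandas_type == "datetime64":
            # Timestamps go through to_datetime rather than astype.
            dataframe[col] = pd.to_datetime(dataframe[col])
        elif pandas_type == "date":
            # pandas has no native date dtype, so store Python datetime.date objects.
            dataframe[col] = pd.to_datetime(dataframe[col]).dt.date
        else:
            dataframe[col] = dataframe[col].astype(pandas_type)
    return dataframe


df = pd.DataFrame({"partcol": ["2019-11-09", "2019-11-08"], "col_double": ["1.0", "1.1"]})
df = cast_pandas(df, {"partcol": "timestamp", "col_double": "double"})
print(df.dtypes)  # partcol -> datetime64[ns], col_double -> float64
```

In the patched writer, `cast_columns_materialized` then strips the partition columns from this mapping before the per-partition writes, so the object writer never tries to cast a column that was already dropped into the S3 partition path.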

testing/test_awswrangler/test_pandas.py (150 additions, 0 deletions)

```diff
@@ -949,3 +949,153 @@ def test_read_sql_athena_with_nulls(session, bucket, database):
     assert df2.dtypes[1] == "bool"
     assert df2.dtypes[2] == "object"
     session.s3.delete_objects(path=path)
+
+
+def test_partition_date(session, bucket, database):
+    df = pd.DataFrame({
+        "col1": ["val1", "val2"],
+        "datecol": ["2019-11-09", "2019-11-08"],
+        "partcol": ["2019-11-09", "2019-11-08"]
+    })
+    df["datecol"] = pd.to_datetime(df.datecol).dt.date
+    df["partcol"] = pd.to_datetime(df.partcol).dt.date
+    print(df)
+    print(df.dtypes)
+    path = f"s3://{bucket}/test/"
+    session.pandas.to_parquet(dataframe=df,
+                              database=database,
+                              path=path,
+                              partition_cols=["datecol"],
+                              preserve_index=False,
+                              mode="overwrite")
+    df2 = None
+    for counter in range(10):
+        df2 = session.pandas.read_sql_athena(sql="select * from test", database=database)
+        assert len(list(df.columns)) == len(list(df2.columns))
+        if len(df.index) == len(df2.index):
+            break
+        sleep(1)
+    assert len(df.index) == len(df2.index)
+    print(df2)
+    print(df2.dtypes)
+    assert df2.dtypes[0] == "object"
+    assert df2.dtypes[1] == "object"
+    assert df2.dtypes[2] == "object"
+    session.s3.delete_objects(path=path)
+
+
+def test_partition_cast_date(session, bucket, database):
+    df = pd.DataFrame({
+        "col1": ["val1", "val2"],
+        "datecol": ["2019-11-09", "2019-11-08"],
+        "partcol": ["2019-11-09", "2019-11-08"]
+    })
+    print(df)
+    print(df.dtypes)
+    path = f"s3://{bucket}/test/"
+    schema = {
+        "col1": "string",
+        "datecol": "date",
+        "partcol": "date",
+    }
+    session.pandas.to_parquet(dataframe=df,
+                              database=database,
+                              path=path,
+                              partition_cols=["partcol"],
+                              preserve_index=False,
+                              cast_columns=schema,
+                              mode="overwrite")
+    df2 = None
+    for counter in range(10):
+        df2 = session.pandas.read_sql_athena(sql="select * from test", database=database)
+        assert len(list(df.columns)) == len(list(df2.columns))
+        if len(df.index) == len(df2.index):
+            break
+        sleep(1)
+    assert len(df.index) == len(df2.index)
+    print(df2)
+    print(df2.dtypes)
+    assert df2.dtypes[0] == "object"
+    assert df2.dtypes[1] == "object"
+    assert df2.dtypes[2] == "object"
+    session.s3.delete_objects(path=path)
+
+
+def test_partition_cast_timestamp(session, bucket, database):
+    df = pd.DataFrame({
+        "col1": ["val1", "val2"],
+        "datecol": ["2019-11-09", "2019-11-08"],
+        "partcol": ["2019-11-09", "2019-11-08"]
+    })
+    print(df)
+    print(df.dtypes)
+    path = f"s3://{bucket}/test/"
+    schema = {
+        "col1": "string",
+        "datecol": "timestamp",
+        "partcol": "timestamp",
+    }
+    session.pandas.to_parquet(dataframe=df,
+                              database=database,
+                              path=path,
+                              partition_cols=["partcol"],
+                              preserve_index=False,
+                              cast_columns=schema,
+                              mode="overwrite")
+    df2 = None
+    for counter in range(10):
+        df2 = session.pandas.read_sql_athena(sql="select * from test", database=database)
+        assert len(list(df.columns)) == len(list(df2.columns))
+        if len(df.index) == len(df2.index):
+            break
+        sleep(1)
+    assert len(df.index) == len(df2.index)
+    print(df2)
+    print(df2.dtypes)
+    assert str(df2.dtypes[0]) == "object"
+    assert str(df2.dtypes[1]).startswith("datetime64")
+    assert str(df2.dtypes[2]).startswith("datetime64")
+    session.s3.delete_objects(path=path)
+
+
+def test_partition_cast(session, bucket, database):
+    df = pd.DataFrame({
+        "col1": ["val1", "val2"],
+        "datecol": ["2019-11-09", "2019-11-08"],
+        "partcol": ["2019-11-09", "2019-11-08"],
+        "col_double": ["1.0", "1.1"],
+        "col_bool": ["True", "False"],
+    })
+    print(df)
+    print(df.dtypes)
+    path = f"s3://{bucket}/test/"
+    schema = {
+        "col1": "string",
+        "datecol": "timestamp",
+        "partcol": "timestamp",
+        "col_double": "double",
+        "col_bool": "boolean",
+    }
+    session.pandas.to_parquet(dataframe=df,
+                              database=database,
+                              path=path,
+                              partition_cols=["partcol"],
+                              preserve_index=False,
+                              cast_columns=schema,
+                              mode="overwrite")
+    df2 = None
+    for counter in range(10):
+        df2 = session.pandas.read_sql_athena(sql="select * from test", database=database)
+        assert len(list(df.columns)) == len(list(df2.columns))
+        if len(df.index) == len(df2.index):
+            break
+        sleep(1)
+    assert len(df.index) == len(df2.index)
+    print(df2)
+    print(df2.dtypes)
+    assert df2.dtypes[0] == "object"
+    assert str(df2.dtypes[1]).startswith("datetime")
+    assert str(df2.dtypes[2]).startswith("float")
+    assert str(df2.dtypes[3]).startswith("bool")
+    assert str(df2.dtypes[4]).startswith("datetime")
+    session.s3.delete_objects(path=path)
```
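Distilled from the tests above, a minimal usage sketch of the case this commit fixes: casting a column that also serves as a partition column. It assumes a configured awswrangler 0.0.18 `Session`; `my-bucket` and `my_database` are placeholder names for an existing S3 bucket and Glue database:

```python
import awswrangler
import pandas as pd

session = awswrangler.Session()

df = pd.DataFrame({
    "col1": ["val1", "val2"],
    "datecol": ["2019-11-09", "2019-11-08"],
    "partcol": ["2019-11-09", "2019-11-08"],
})

# "partcol" is both cast (via cast_columns) and used for partitioning --
# exactly the combination that previously broke.
session.pandas.to_parquet(dataframe=df,
                          database="my_database",       # placeholder Glue database
                          path="s3://my-bucket/test/",  # placeholder S3 path
                          partition_cols=["partcol"],
                          preserve_index=False,
                          cast_columns={"datecol": "date", "partcol": "date"},
                          mode="overwrite")

df2 = session.pandas.read_sql_athena(sql="select * from test", database="my_database")
print(df2.dtypes)
```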
