Skip to content

Commit 0df629a

Browse files
authored
Merge pull request #66 from awslabs/casting-nan
Fixing cast for NaN values
2 parents 715cea1 + 99897cd commit 0df629a

File tree

3 files changed

+66
-14
lines changed

3 files changed

+66
-14
lines changed

awswrangler/pandas.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -587,20 +587,20 @@ def to_parquet(self,
587587
inplace=inplace)
588588

589589
def to_s3(self,
590-
dataframe,
591-
path,
592-
file_format,
593-
database=None,
594-
table=None,
590+
dataframe: pd.DataFrame,
591+
path: str,
592+
file_format: str,
593+
database: Optional[str] = None,
594+
table: Optional[str] = None,
595595
partition_cols=None,
596596
preserve_index=True,
597-
mode="append",
597+
mode: str = "append",
598598
compression=None,
599599
procs_cpu_bound=None,
600600
procs_io_bound=None,
601601
cast_columns=None,
602602
extra_args=None,
603-
inplace=True):
603+
inplace: bool = True) -> List[str]:
604604
"""
605605
Write a Pandas Dataframe on S3
606606
Optionally writes metadata on AWS Glue.
@@ -621,9 +621,9 @@ def to_s3(self,
621621
:param inplace: True is cheapest (CPU and Memory) but False leaves your DataFrame intact
622622
:return: List of objects written on S3
623623
"""
624-
if not partition_cols:
624+
if partition_cols is None:
625625
partition_cols = []
626-
if not cast_columns:
626+
if cast_columns is None:
627627
cast_columns = {}
628628
dataframe = Pandas.normalize_columns_names_athena(dataframe, inplace=inplace)
629629
cast_columns = {Athena.normalize_column_name(k): v for k, v in cast_columns.items()}
@@ -748,20 +748,20 @@ def _data_to_s3_dataset_writer(dataframe,
748748
extra_args=None,
749749
isolated_dataframe=False):
750750
objects_paths = []
751+
dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
752+
cast_columns_materialized = {c: t for c, t in cast_columns.items() if c not in partition_cols}
751753
if not partition_cols:
752754
object_path = Pandas._data_to_s3_object_writer(dataframe=dataframe,
753755
path=path,
754756
preserve_index=preserve_index,
755757
compression=compression,
756758
session_primitives=session_primitives,
757759
file_format=file_format,
758-
cast_columns=cast_columns,
760+
cast_columns=cast_columns_materialized,
759761
extra_args=extra_args,
760762
isolated_dataframe=isolated_dataframe)
761763
objects_paths.append(object_path)
762764
else:
763-
dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
764-
cast_columns_materialized = {c: t for c, t in cast_columns.items() if c not in partition_cols}
765765
dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
766766
for keys, subgroup in dataframe.groupby(partition_cols):
767767
subgroup = subgroup.drop(partition_cols, axis="columns")
@@ -790,7 +790,7 @@ def _cast_pandas(dataframe: pd.DataFrame, cast_columns: Dict[str, str]) -> pd.Da
790790
elif pandas_type == "date":
791791
dataframe[col] = pd.to_datetime(dataframe[col]).dt.date
792792
else:
793-
dataframe[col] = dataframe[col].astype(pandas_type)
793+
dataframe[col] = dataframe[col].astype(pandas_type, skipna=True)
794794
return dataframe
795795

796796
@staticmethod

data_samples/nan.csv

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"col1","col2","col3","col4","pt"
2+
,,,,"1"
3+
,,,,"2"
4+
,"foo","bar","baz","1"
5+
,"foo","bar","baz","2"

testing/test_awswrangler/test_pandas.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,7 @@ def test_read_sql_athena_with_nulls(session, bucket, database):
947947
print(df2.dtypes)
948948
assert df2.dtypes[0] == "Int64"
949949
assert df2.dtypes[1] == "bool"
950-
assert df2.dtypes[2] == "object"
950+
assert df2.dtypes[2] == "bool"
951951
session.s3.delete_objects(path=path)
952952

953953

@@ -1149,3 +1149,50 @@ def test_partition_single_row(session, bucket, database, procs):
11491149
assert df2.dtypes[1] == "object"
11501150
assert df2.dtypes[2] == "object"
11511151
session.s3.delete_objects(path=path)
1152+
1153+
1154+
@pytest.mark.parametrize("partition_cols", [None, ["pt"]])
1155+
def test_nan_cast(session, bucket, database, partition_cols):
1156+
dtypes = {"col1": "object", "col2": "object", "col3": "object", "col4": "object", "pt": "object"}
1157+
df = pd.read_csv("data_samples/nan.csv", dtype=dtypes)
1158+
print(df)
1159+
schema = {
1160+
"col1": "string",
1161+
"col2": "string",
1162+
"col3": "string",
1163+
"col4": "string",
1164+
"pt": "string",
1165+
}
1166+
path = f"s3://{bucket}/test/"
1167+
session.pandas.to_parquet(dataframe=df,
1168+
database=database,
1169+
path=path,
1170+
partition_cols=partition_cols,
1171+
mode="overwrite",
1172+
cast_columns=schema)
1173+
df2 = None
1174+
for counter in range(10):
1175+
sleep(1)
1176+
df2 = session.pandas.read_sql_athena(sql="select * from test", database=database)
1177+
assert len(list(df.columns)) == len(list(df2.columns)) - 1
1178+
if len(df.index) == len(df2.index):
1179+
break
1180+
print(df2.dtypes)
1181+
assert len(df.index) == len(df2.index)
1182+
assert df2.dtypes[0] == "object"
1183+
assert df2.dtypes[1] == "object"
1184+
assert df2.dtypes[2] == "object"
1185+
assert df2.dtypes[3] == "object"
1186+
assert df2.iloc[:, 0].isna().sum() == 4
1187+
assert df2.iloc[:, 1].isna().sum() == 2
1188+
assert df2.iloc[:, 2].isna().sum() == 2
1189+
assert df2.iloc[:, 3].isna().sum() == 2
1190+
assert df2.iloc[:, 4].isna().sum() == 0
1191+
assert df2.iloc[:, 5].isna().sum() == 0
1192+
if partition_cols is None:
1193+
assert df2.dtypes[4] == "object"
1194+
assert df2.dtypes[5] == "Int64"
1195+
else:
1196+
assert df2.dtypes[4] == "Int64"
1197+
assert df2.dtypes[5] == "object"
1198+
session.s3.delete_objects(path=path)

0 commit comments

Comments
 (0)