Skip to content

Commit 24d315e

Browse files
authored
Merge pull request #78 from awslabs/decimal
add Decimal type support
2 parents 26bcddf + 2b4fa3f commit 24d315e

File tree

4 files changed

+134
-2
lines changed

4 files changed

+134
-2
lines changed

awswrangler/data_types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def athena2pandas(dtype: str) -> str:
2626
return "date"
2727
elif dtype == "array":
2828
return "list"
29+
elif dtype == "decimal":
30+
return "decimal"
2931
else:
3032
raise UnsupportedType(f"Unsupported Athena type: {dtype}")
3133

@@ -162,6 +164,8 @@ def pyarrow2athena(dtype: pa.types) -> str:
162164
return "timestamp"
163165
elif dtype_str.startswith("date"):
164166
return "date"
167+
elif dtype_str.startswith("decimal"):
168+
return dtype_str.replace(" ", "")
165169
elif dtype_str.startswith("list"):
166170
return f"array<{pyarrow2athena(dtype.value_type)}>"
167171
elif dtype_str == "null":
@@ -190,6 +194,8 @@ def pyarrow2redshift(dtype: pa.types) -> str:
190194
return "TIMESTAMP"
191195
elif dtype_str.startswith("date"):
192196
return "DATE"
197+
elif dtype_str.startswith("decimal"):
198+
return dtype_str.replace(" ", "").upper()
193199
else:
194200
raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
195201

@@ -280,6 +286,8 @@ def spark2redshift(dtype: str) -> str:
280286
return "DATE"
281287
elif dtype == "string":
282288
return "VARCHAR(256)"
289+
elif dtype.startswith("decimal"):
290+
return dtype.replace(" ", "").upper()
283291
else:
284292
raise UnsupportedType("Unsupported Spark type: " + dtype)
285293

awswrangler/pandas.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import copy
77
import csv
88
from datetime import datetime
9-
import ast
9+
from decimal import Decimal
1010

1111
from botocore.exceptions import ClientError, HTTPClientError # type: ignore
1212
import pandas as pd # type: ignore
@@ -471,6 +471,8 @@ def _get_query_dtype(self, query_execution_id: str) -> Tuple[Dict[str, str], Lis
471471
converters[col_name] = Pandas._list_parser
472472
elif pandas_type == "bool":
473473
logger.debug(f"Ignoring bool column: {col_name}")
474+
elif pandas_type == "decimal":
475+
converters[col_name] = lambda x: Decimal(str(x)) if str(x) != "" else None
474476
else:
475477
dtype[col_name] = pandas_type
476478
logger.debug(f"dtype: {dtype}")

testing/test_awswrangler/test_pandas.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import csv
44
from datetime import datetime, date
5+
from decimal import Decimal
56

67
import pytest
78
import boto3
@@ -1303,7 +1304,6 @@ def test_to_parquet_array(session, bucket, database):
13031304
df2 = session.pandas.read_sql_athena(sql="select * from test", database=database)
13041305
if len(df.index) == len(df2.index):
13051306
break
1306-
print(df2)
13071307
session.s3.delete_objects(path=path)
13081308

13091309
assert len(list(df.columns)) == len(list(df2.columns))
@@ -1312,3 +1312,37 @@ def test_to_parquet_array(session, bucket, database):
13121312
assert df2[df2.a == 2].iloc[0].b[0] == 4.0
13131313
assert df2[df2.a == 2].iloc[0].c[0] == 7
13141314
assert df2[df2.a == 2].iloc[0].d[0] == "foo"
1315+
1316+
1317+
def test_to_parquet_decimal(session, bucket, database):
    """Round-trip decimal columns: write a DataFrame to Parquet on S3, then read it back via Athena."""
    dataframe = pd.DataFrame({
        "id": [1, 2, 3],
        # Decimal((sign, digits, exponent)): 1.99 / None / 1.90 at 2 fractional digits
        "decimal_2": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))],
        # 1.99999 / None / 1.90000 at 5 fractional digits
        "decimal_5": [Decimal((0, (1, 9, 9, 9, 9, 9), -5)), None, Decimal((0, (1, 9, 0, 0, 0, 0), -5))],
    })
    print(dataframe)
    print(dataframe.dtypes)
    s3_path = f"s3://{bucket}/test/"
    session.pandas.to_parquet(dataframe=dataframe,
                              database=database,
                              path=s3_path,
                              mode="overwrite",
                              preserve_index=False,
                              procs_cpu_bound=1)
    result = None
    for _ in range(10):  # Retrying to workaround s3 eventual consistency
        sleep(1)
        result = session.pandas.read_sql_athena(sql="select * from test", database=database)
        if len(dataframe.index) == len(result.index):
            break
    session.s3.delete_objects(path=s3_path)

    assert len(list(dataframe.columns)) == len(list(result.columns))
    assert len(dataframe.index) == len(result.index)

    first = result[result.id == 1].iloc[0]
    assert first.decimal_2 == Decimal((0, (1, 9, 9), -2))
    assert first.decimal_5 == Decimal((0, (1, 9, 9, 9, 9, 9), -5))
    second = result[result.id == 2].iloc[0]
    assert second.decimal_2 is None
    assert second.decimal_5 is None
    third = result[result.id == 3].iloc[0]
    assert third.decimal_2 == Decimal((0, (1, 9, 0), -2))
    assert third.decimal_5 == Decimal((0, (1, 9, 0, 0, 0, 0), -5))

testing/test_awswrangler/test_redshift.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import logging
33
from datetime import date, datetime
4+
from decimal import Decimal
45

56
import pytest
67
import boto3
@@ -421,3 +422,90 @@ def test_connection_with_different_port_types(redshift_parameters):
421422
password=redshift_parameters.get("RedshiftPassword"),
422423
)
423424
conn.close()
425+
426+
427+
def test_to_redshift_pandas_decimal(session, bucket, redshift_parameters):
    """Load a pandas DataFrame with decimal columns into Redshift and verify every row.

    Fix: the elif branches previously compared the wrong columns
    (``row[1] == 2`` and ``row[2] == 3`` instead of ``row[0]``), so the
    assertions for rows 2 and 3 never ran and the test silently passed
    without checking them.
    """
    df = pd.DataFrame({
        "id": [1, 2, 3],
        # Decimal((sign, digits, exponent)): 1.99 / None / 1.90
        "decimal_2": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))],
        # 1.99999 / None / 1.90000
        "decimal_5": [Decimal((0, (1, 9, 9, 9, 9, 9), -5)), None, Decimal((0, (1, 9, 0, 0, 0, 0), -5))],
    })
    con = Redshift.generate_connection(
        database="test",
        host=redshift_parameters.get("RedshiftAddress"),
        port=redshift_parameters.get("RedshiftPort"),
        user="test",
        password=redshift_parameters.get("RedshiftPassword"),
    )
    path = f"s3://{bucket}/redshift-load/"
    session.pandas.to_redshift(
        dataframe=df,
        path=path,
        schema="public",
        table="test",
        connection=con,
        iam_role=redshift_parameters.get("RedshiftRole"),
        mode="overwrite",
        preserve_index=False,
    )
    cursor = con.cursor()
    cursor.execute("SELECT * from public.test")
    rows = cursor.fetchall()
    cursor.close()
    con.close()
    assert len(df.index) == len(rows)
    assert len(list(df.columns)) == len(list(rows[0]))
    print(rows)
    # row[0] is the id column; dispatch on it (was wrongly row[1]/row[2]).
    for row in rows:
        if row[0] == 1:
            assert row[1] == Decimal((0, (1, 9, 9), -2))
            assert row[2] == Decimal((0, (1, 9, 9, 9, 9, 9), -5))
        elif row[0] == 2:
            assert row[1] is None
            assert row[2] is None
        elif row[0] == 3:
            assert row[1] == Decimal((0, (1, 9, 0), -2))
            assert row[2] == Decimal((0, (1, 9, 0, 0, 0, 0), -5))
469+
470+
471+
def test_to_redshift_spark_decimal(session, bucket, redshift_parameters):
    """Load a Spark DataFrame with DECIMAL columns into Redshift and verify every row.

    Fix: the elif branches previously compared the wrong columns
    (``row[1] == 2`` and ``row[2] == 3`` instead of ``row[0]``), so the
    assertions for rows 2 and 3 never ran and the test silently passed
    without checking them.
    """
    df = session.spark_session.createDataFrame(pd.DataFrame({
        "id": [1, 2, 3],
        # Decimal((sign, digits, exponent)): 1.99 / None / 1.90 and 1.99999 / None / 1.90000
        "decimal_2": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))],
        "decimal_5": [Decimal((0, (1, 9, 9, 9, 9, 9), -5)), None, Decimal((0, (1, 9, 0, 0, 0, 0), -5))]}),
        schema="id INTEGER, decimal_2 DECIMAL(3,2), decimal_5 DECIMAL(6,5)")
    con = Redshift.generate_connection(
        database="test",
        host=redshift_parameters.get("RedshiftAddress"),
        port=redshift_parameters.get("RedshiftPort"),
        user="test",
        password=redshift_parameters.get("RedshiftPassword"),
    )
    path = f"s3://{bucket}/redshift-load2/"
    session.spark.to_redshift(
        dataframe=df,
        path=path,
        schema="public",
        table="test2",
        connection=con,
        iam_role=redshift_parameters.get("RedshiftRole"),
        mode="overwrite",
    )
    cursor = con.cursor()
    cursor.execute("SELECT * from public.test2")
    rows = cursor.fetchall()
    cursor.close()
    con.close()
    assert df.count() == len(rows)
    assert len(list(df.columns)) == len(list(rows[0]))
    print(rows)
    # row[0] is the id column; dispatch on it (was wrongly row[1]/row[2]).
    for row in rows:
        if row[0] == 1:
            assert row[1] == Decimal((0, (1, 9, 9), -2))
            assert row[2] == Decimal((0, (1, 9, 9, 9, 9, 9), -5))
        elif row[0] == 2:
            assert row[1] is None
            assert row[2] is None
        elif row[0] == 3:
            assert row[1] == Decimal((0, (1, 9, 0), -2))
            assert row[2] == Decimal((0, (1, 9, 0, 0, 0, 0), -5))

0 commit comments

Comments
 (0)