Commit a755c71

Bumping version to 0.0.24
1 parent: 24d315e

16 files changed: +262 −270 lines


README.md

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 
 > Utility belt to handle data on AWS.
 
-[![Release](https://img.shields.io/badge/release-0.0.23-brightgreen.svg)](https://pypi.org/project/awswrangler/)
+[![Release](https://img.shields.io/badge/release-0.0.24-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 [![Downloads](https://img.shields.io/pypi/dm/awswrangler.svg)](https://pypi.org/project/awswrangler/)
 [![Python Version](https://img.shields.io/badge/python-3.6%20%7C%203.7-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 [![Documentation Status](https://readthedocs.org/projects/aws-data-wrangler/badge/?version=latest)](https://aws-data-wrangler.readthedocs.io/en/latest/?badge=latest)

awswrangler/__version__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 __title__ = "awswrangler"
 __description__ = "Utility belt to handle data on AWS."
-__version__ = "0.0.23"
+__version__ = "0.0.24"
 __license__ = "Apache License 2.0"

awswrangler/athena.py

Lines changed: 38 additions & 9 deletions

@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple, Optional, Any, Iterator, Union
+from typing import Dict, List, Tuple, Optional, Any, Iterator
 from time import sleep
 import logging
 import re
@@ -25,7 +25,6 @@ def get_query_columns_metadata(self, query_execution_id: str) -> Dict[str, str]:
         """
         response: Dict = self._client_athena.get_query_results(QueryExecutionId=query_execution_id, MaxResults=1)
         col_info: List[Dict[str, str]] = response["ResultSet"]["ResultSetMetadata"]["ColumnInfo"]
-        logger.debug(f"col_info: {col_info}")
         return {x["Name"]: x["Type"] for x in col_info}
 
     def create_athena_bucket(self):
@@ -42,7 +41,13 @@ def create_athena_bucket(self):
         s3_resource.Bucket(s3_output)
         return s3_output
 
-    def run_query(self, query: str, database: Optional[str] = None, s3_output: Optional[str] = None, workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None) -> str:
+    def run_query(self,
+                  query: str,
+                  database: Optional[str] = None,
+                  s3_output: Optional[str] = None,
+                  workgroup: Optional[str] = None,
+                  encryption: Optional[str] = None,
+                  kms_key: Optional[str] = None) -> str:
         """
         Run a SQL Query against AWS Athena
         P.S All default values will be inherited from the Session()
@@ -55,7 +60,7 @@ def run_query(self, query: str, database: Optional[str] = None, s3_output: Optio
         :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
         :return: Query execution ID
         """
-        args: Dict[str, Union[str, Dict[str, Union[str, Dict[str, str]]]]] = {"QueryString": query}
+        args: Dict[str, Any] = {"QueryString": query}
 
         # s3_output
         if s3_output is None:
@@ -71,7 +76,9 @@ def run_query(self, query: str, database: Optional[str] = None, s3_output: Optio
             if kms_key is not None:
                 args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = kms_key
         elif self._session.athena_encryption is not None:
-            args["ResultConfiguration"]["EncryptionConfiguration"] = {"EncryptionOption": self._session.athena_encryption}
+            args["ResultConfiguration"]["EncryptionConfiguration"] = {
+                "EncryptionOption": self._session.athena_encryption
+            }
             if self._session.athena_kms_key is not None:
                 args["ResultConfiguration"]["EncryptionConfiguration"]["KmsKey"] = self._session.athena_kms_key
 
@@ -113,7 +120,13 @@ def wait_query(self, query_execution_id):
             raise QueryCancelled(response["QueryExecution"]["Status"].get("StateChangeReason"))
         return response
 
-    def repair_table(self, table: str, database: Optional[str] = None, s3_output: Optional[str] = None, workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None):
+    def repair_table(self,
+                     table: str,
+                     database: Optional[str] = None,
+                     s3_output: Optional[str] = None,
+                     workgroup: Optional[str] = None,
+                     encryption: Optional[str] = None,
+                     kms_key: Optional[str] = None):
         """
         Hive's metastore consistency check
         "MSCK REPAIR TABLE table;"
@@ -133,7 +146,12 @@ def repair_table(self, table: str, database: Optional[str] = None, s3_output: Op
         :return: Query execution ID
         """
         query = f"MSCK REPAIR TABLE {table};"
-        query_id = self.run_query(query=query, database=database, s3_output=s3_output, workgroup=workgroup, encryption=encryption, kms_key=kms_key)
+        query_id = self.run_query(query=query,
+                                  database=database,
+                                  s3_output=s3_output,
+                                  workgroup=workgroup,
+                                  encryption=encryption,
+                                  kms_key=kms_key)
         self.wait_query(query_execution_id=query_id)
         return query_id
 
@@ -174,7 +192,13 @@ def get_results(self, query_execution_id: str) -> Iterator[Dict[str, Any]]:
                 yield row
             next_token = res.get("NextToken")
 
-    def query(self, query: str, database: Optional[str] = None, s3_output: Optional[str] = None, workgroup: Optional[str] = None, encryption: Optional[str] = None, kms_key: Optional[str] = None) -> Iterator[Dict[str, Any]]:
+    def query(self,
+              query: str,
+              database: Optional[str] = None,
+              s3_output: Optional[str] = None,
+              workgroup: Optional[str] = None,
+              encryption: Optional[str] = None,
+              kms_key: Optional[str] = None) -> Iterator[Dict[str, Any]]:
         """
         Run a SQL Query against AWS Athena and return the result as a Iterator of lists
         P.S All default values will be inherited from the Session()
@@ -187,7 +211,12 @@ def query(self, query: str, database: Optional[str] = None, s3_output: Optional[
         :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
         :return: Query execution ID
         """
-        query_id: str = self.run_query(query=query, database=database, s3_output=s3_output, workgroup=workgroup, encryption=encryption, kms_key=kms_key)
+        query_id: str = self.run_query(query=query,
+                                       database=database,
+                                       s3_output=s3_output,
+                                       workgroup=workgroup,
+                                       encryption=encryption,
+                                       kms_key=kms_key)
         self.wait_query(query_execution_id=query_id)
         return self.get_results(query_execution_id=query_id)
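
The refactored signatures above only change formatting, not behavior. As a usage sketch (not part of this commit), the new keyword layout reads naturally at the call site; the Session setup and the database, workgroup, and KMS key values below are assumptions for illustration:

import awswrangler

# A minimal sketch, assuming default AWS credentials; database, workgroup,
# and kms_key are hypothetical placeholders, not values from this commit.
session = awswrangler.Session()
query_execution_id = session.athena.run_query(query="SELECT 1",
                                              database="my_database",
                                              workgroup="primary",
                                              encryption="SSE_KMS",
                                              kms_key="alias/my-key")
session.athena.wait_query(query_execution_id=query_execution_id)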

awswrangler/data_types.py

Lines changed: 2 additions & 1 deletion

@@ -304,7 +304,8 @@ def convert_schema(func: Callable, schema: List[Tuple[str, str]]) -> Dict[str, s
     return {name: func(dtype) for name, dtype in schema}
 
 
-def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame, preserve_index: bool,
+def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
+                                       preserve_index: bool,
                                        indexes_position: str = "right") -> List[Tuple[str, str]]:
     """
     Extract the related Pyarrow schema from any Pandas DataFrame
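
For reference, a minimal sketch of calling the reflowed helper; the DataFrame content and the printed type names are made up for illustration:

import pandas as pd
from awswrangler import data_types

# Per the signature above, the result is a list of (column_name, type_name) tuples.
df = pd.DataFrame({"id": [1, 2], "name": ["foo", "bar"]})
schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=df,
                                                       preserve_index=False,
                                                       indexes_position="right")
print(schema)  # e.g. [("id", "int64"), ("name", "string")]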

awswrangler/emr.py

Lines changed: 4 additions & 1 deletion

@@ -480,7 +480,10 @@ def submit_step(self,
         logger.info(f"response: \n{json.dumps(response, default=str, indent=4)}")
         return response["StepIds"][0]
 
-    def build_step(self, name: str, command: str, action_on_failure: str = "CONTINUE",
+    def build_step(self,
+                   name: str,
+                   command: str,
+                   action_on_failure: str = "CONTINUE",
                    script: bool = False) -> Dict[str, Collection[str]]:
         """
         Build the Step dictionary
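
A hypothetical call for context, assuming the EMR helpers are reachable through the Session object; the step name and command are invented:

import awswrangler

# Sketch only: build_step returns an EMR Step definition dict.
session = awswrangler.Session()
step = session.emr.build_step(name="my-step",                               # hypothetical
                              command="spark-submit s3://my-bucket/job.py",  # hypothetical
                              action_on_failure="CONTINUE",
                              script=False)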

awswrangler/pandas.py

Lines changed: 93 additions & 73 deletions

@@ -7,6 +7,7 @@
 import csv
 from datetime import datetime
 from decimal import Decimal
+from ast import literal_eval
 
 from botocore.exceptions import ClientError, HTTPClientError  # type: ignore
 import pandas as pd  # type: ignore
@@ -46,24 +47,24 @@ def _parse_path(path):
         return parts[0], parts[2]
 
     def read_csv(
-            self,
-            path,
-            max_result_size=None,
-            header="infer",
-            names=None,
-            usecols=None,
-            dtype=None,
-            sep=",",
-            thousands=None,
-            decimal=".",
-            lineterminator="\n",
-            quotechar='"',
-            quoting=csv.QUOTE_MINIMAL,
-            escapechar=None,
-            parse_dates: Union[bool, Dict, List] = False,
-            infer_datetime_format=False,
-            encoding="utf-8",
-            converters=None,
+        self,
+        path,
+        max_result_size=None,
+        header="infer",
+        names=None,
+        usecols=None,
+        dtype=None,
+        sep=",",
+        thousands=None,
+        decimal=".",
+        lineterminator="\n",
+        quotechar='"',
+        quoting=csv.QUOTE_MINIMAL,
+        escapechar=None,
+        parse_dates: Union[bool, Dict, List] = False,
+        infer_datetime_format=False,
+        encoding="utf-8",
+        converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -137,25 +138,25 @@ def read_csv(
 
     @staticmethod
     def _read_csv_iterator(
-            client_s3,
-            bucket_name,
-            key_path,
-            max_result_size=200_000_000,  # 200 MB
-            header="infer",
-            names=None,
-            usecols=None,
-            dtype=None,
-            sep=",",
-            thousands=None,
-            decimal=".",
-            lineterminator="\n",
-            quotechar='"',
-            quoting=csv.QUOTE_MINIMAL,
-            escapechar=None,
-            parse_dates: Union[bool, Dict, List] = False,
-            infer_datetime_format=False,
-            encoding="utf-8",
-            converters=None,
+        client_s3,
+        bucket_name,
+        key_path,
+        max_result_size=200_000_000,  # 200 MB
+        header="infer",
+        names=None,
+        usecols=None,
+        dtype=None,
+        sep=",",
+        thousands=None,
+        decimal=".",
+        lineterminator="\n",
+        quotechar='"',
+        quoting=csv.QUOTE_MINIMAL,
+        escapechar=None,
+        parse_dates: Union[bool, Dict, List] = False,
+        infer_datetime_format=False,
+        encoding="utf-8",
+        converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -350,24 +351,24 @@ def _find_terminator(body, sep, quoting, quotechar, lineterminator):
 
     @staticmethod
     def _read_csv_once(
-            client_s3,
-            bucket_name,
-            key_path,
-            header="infer",
-            names=None,
-            usecols=None,
-            dtype=None,
-            sep=",",
-            thousands=None,
-            decimal=".",
-            lineterminator="\n",
-            quotechar='"',
-            quoting=0,
-            escapechar=None,
-            parse_dates: Union[bool, Dict, List] = False,
-            infer_datetime_format=False,
-            encoding=None,
-            converters=None,
+        client_s3,
+        bucket_name,
+        key_path,
+        header="infer",
+        names=None,
+        usecols=None,
+        dtype=None,
+        sep=",",
+        thousands=None,
+        decimal=".",
+        lineterminator="\n",
+        quotechar='"',
+        quoting=0,
+        escapechar=None,
+        parse_dates: Union[bool, Dict, List] = False,
+        infer_datetime_format=False,
+        encoding=None,
+        converters=None,
     ):
         """
         Read CSV file from AWS S3 using optimized strategies.
@@ -420,9 +421,17 @@ def _read_csv_once(
 
     @staticmethod
     def _list_parser(value: str) -> List[Union[int, float, str, None]]:
+        # try resolve with a simple literal_eval
+        try:
+            return literal_eval(value)
+        except ValueError:
+            pass  # keep trying
+
+        # sanity check
         if len(value) <= 1:
             return []
-        items: List[None, str] = [None if x == "null" else x for x in value[1:-1].split(", ")]
+
+        items: List[Union[None, str]] = [None if x == "null" else x for x in value[1:-1].split(", ")]
         array_type: Optional[type] = None
 
         # check if all values are integers
@@ -481,8 +490,14 @@ def _get_query_dtype(self, query_execution_id: str) -> Tuple[Dict[str, str], Lis
         logger.debug(f"converters: {converters}")
         return dtype, parse_timestamps, parse_dates, converters
 
-    def read_sql_athena(self, sql, database=None, s3_output=None, max_result_size=None, workgroup=None,
-                        encryption=None, kms_key=None):
+    def read_sql_athena(self,
+                        sql,
+                        database=None,
+                        s3_output=None,
+                        max_result_size=None,
+                        workgroup=None,
+                        encryption=None,
+                        kms_key=None):
         """
         Executes any SQL query on AWS Athena and return a Dataframe of the result.
         P.S. If max_result_size is passed, then a iterator of Dataframes is returned.
@@ -499,7 +514,12 @@ def read_sql_athena(self, sql, database=None, s3_output=None, max_result_size=No
         """
         if not s3_output:
             s3_output = self._session.athena.create_athena_bucket()
-        query_execution_id = self._session.athena.run_query(query=sql, database=database, s3_output=s3_output, workgroup=workgroup, encryption=encryption, kms_key=kms_key)
+        query_execution_id = self._session.athena.run_query(query=sql,
+                                                            database=database,
+                                                            s3_output=s3_output,
+                                                            workgroup=workgroup,
+                                                            encryption=encryption,
+                                                            kms_key=kms_key)
         query_response = self._session.athena.wait_query(query_execution_id=query_execution_id)
         if query_response["QueryExecution"]["Status"]["State"] in ["FAILED", "CANCELLED"]:
             reason = query_response["QueryExecution"]["Status"]["StateChangeReason"]
@@ -532,19 +552,19 @@ def _apply_dates_to_generator(generator, parse_dates):
             yield df
 
     def to_csv(
-            self,
-            dataframe,
-            path,
-            sep=",",
-            serde="OpenCSVSerDe",
-            database: Optional[str] = None,
-            table=None,
-            partition_cols=None,
-            preserve_index=True,
-            mode="append",
-            procs_cpu_bound=None,
-            procs_io_bound=None,
-            inplace=True,
+        self,
+        dataframe,
+        path,
+        sep=",",
+        serde="OpenCSVSerDe",
+        database: Optional[str] = None,
+        table=None,
+        partition_cols=None,
+        preserve_index=True,
+        mode="append",
+        procs_cpu_bound=None,
+        procs_io_bound=None,
+        inplace=True,
    ):
         """
         Write a Pandas Dataframe as CSV files on S3
@@ -806,7 +826,7 @@ def _data_to_s3_dataset_writer(dataframe,
         for keys, subgroup in dataframe.groupby(partition_cols):
             subgroup = subgroup.drop(partition_cols, axis="columns")
             if not isinstance(keys, tuple):
-                keys = (keys,)
+                keys = (keys, )
             subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)])
             prefix = "/".join([path, subdir])
             object_path = Pandas._data_to_s3_object_writer(dataframe=subgroup,
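
The _list_parser change above first tries ast.literal_eval and only falls back to manual splitting when the string is not a valid Python literal (for example Athena's unquoted null tokens). A standalone sketch of that fallback logic, omitting the original's int/float type promotion:

from ast import literal_eval
from typing import List, Union

def list_parser(value: str) -> List[Union[int, float, str, None]]:
    # fast path: "[1, 2, 3]" is a valid Python literal
    try:
        return literal_eval(value)
    except ValueError:
        pass  # e.g. "[a, null, b]" is not a literal and raises ValueError
    if len(value) <= 1:
        return []
    # strip the brackets and map Athena's "null" tokens to None
    return [None if x == "null" else x for x in value[1:-1].split(", ")]

print(list_parser("[1, 2, 3]"))     # -> [1, 2, 3]
print(list_parser("[a, null, b]"))  # -> ['a', None, 'b']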

awswrangler/redshift.py

Lines changed: 5 additions & 2 deletions

@@ -117,8 +117,11 @@ def get_connection(self, glue_connection):
         conn = self.generate_connection(database=database, host=host, port=int(port), user=user, password=password)
         return conn
 
-    def write_load_manifest(self, manifest_path: str, objects_paths: List[str], procs_io_bound: Optional[int] = None
-                            ) -> Dict[str, List[Dict[str, Union[str, bool, Dict[str, int]]]]]:
+    def write_load_manifest(
+            self,
+            manifest_path: str,
+            objects_paths: List[str],
+            procs_io_bound: Optional[int] = None) -> Dict[str, List[Dict[str, Union[str, bool, Dict[str, int]]]]]:
         objects_sizes: Dict[str, int] = self._session.s3.get_objects_sizes(objects_paths=objects_paths,
                                                                            procs_io_bound=procs_io_bound)
         manifest: Dict[str, List[Dict[str, Union[str, bool, Dict[str, int]]]]] = {"entries": []}
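
For context, the structure this method builds matches the standard Redshift COPY manifest layout implied by the return type above; the object path and size below are placeholders:

# Sketch of the manifest shape (values are invented for illustration):
manifest = {
    "entries": [
        {
            "url": "s3://my-bucket/part-0000.csv",  # one entry per object in objects_paths
            "mandatory": True,                      # COPY fails if the object is missing
            "meta": {"content_length": 1048576},    # byte size from get_objects_sizes
        },
    ]
}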
