 import pyarrow as pa  # type: ignore
 from pyarrow import parquet as pq  # type: ignore
 import tenacity  # type: ignore
+from s3fs import S3FileSystem  # type: ignore
 
 from awswrangler import data_types
 from awswrangler.exceptions import (UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object,
@@ -491,13 +492,13 @@ def _get_query_dtype(self, query_execution_id: str) -> Tuple[Dict[str, str], Lis
         return dtype, parse_timestamps, parse_dates, converters
 
     def read_sql_athena(self,
-                        sql,
-                        database=None,
-                        s3_output=None,
-                        max_result_size=None,
-                        workgroup=None,
-                        encryption=None,
-                        kms_key=None):
+                        sql: str,
+                        database: Optional[str] = None,
+                        s3_output: Optional[str] = None,
+                        max_result_size: Optional[int] = None,
+                        workgroup: Optional[str] = None,
+                        encryption: Optional[str] = None,
+                        kms_key: Optional[str] = None):
         """
         Executes any SQL query on AWS Athena and return a Dataframe of the result.
         P.S. If max_result_size is passed, then a iterator of Dataframes is returned.
@@ -512,18 +513,21 @@ def read_sql_athena(self,
         :param kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size != None
         """
-        if not s3_output:
-            s3_output = self._session.athena.create_athena_bucket()
-        query_execution_id = self._session.athena.run_query(query=sql,
-                                                            database=database,
-                                                            s3_output=s3_output,
-                                                            workgroup=workgroup,
-                                                            encryption=encryption,
-                                                            kms_key=kms_key)
-        query_response = self._session.athena.wait_query(query_execution_id=query_execution_id)
+        if s3_output is None:
+            if self._session.athena_s3_output is not None:
+                s3_output = self._session.athena_s3_output
+            else:
+                s3_output = self._session.athena.create_athena_bucket()
+        query_execution_id: str = self._session.athena.run_query(query=sql,
+                                                                 database=database,
+                                                                 s3_output=s3_output,
+                                                                 workgroup=workgroup,
+                                                                 encryption=encryption,
+                                                                 kms_key=kms_key)
+        query_response: Dict = self._session.athena.wait_query(query_execution_id=query_execution_id)
         if query_response["QueryExecution"]["Status"]["State"] in ["FAILED", "CANCELLED"]:
-            reason = query_response["QueryExecution"]["Status"]["StateChangeReason"]
-            message_error = f"Query error: {reason}"
+            reason: str = query_response["QueryExecution"]["Status"]["StateChangeReason"]
+            message_error: str = f"Query error: {reason}"
             raise AthenaQueryError(message_error)
         else:
             dtype, parse_timestamps, parse_dates, converters = self._get_query_dtype(
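For context, a minimal usage sketch of the new s3_output fallback (a hedged sketch, assuming the awswrangler 0.x Session API and that Session now exposes the athena_s3_output default this change reads; the bucket, database, and table names are hypothetical):

import awswrangler

# Assumed: Session accepts a session-wide athena_s3_output default, as implied by the change above.
session = awswrangler.Session(athena_s3_output="s3://my-bucket/athena-results/")

# s3_output is omitted here, so the session-level default above is used
# instead of falling back to create_athena_bucket().
df = session.pandas.read_sql_athena(sql="SELECT * FROM my_table LIMIT 10",
                                    database="my_database")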
@@ -1133,7 +1137,7 @@ def read_parquet(self,
                      path: str,
                      columns: Optional[List[str]] = None,
                      filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None,
-                     procs_cpu_bound: Optional[int] = None):
+                     procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
         """
         Read parquet data from S3
 
@@ -1145,7 +1149,7 @@ def read_parquet(self,
         path = path[:-1] if path[-1] == "/" else path
         procs_cpu_bound = 1 if self._session.procs_cpu_bound is None else self._session.procs_cpu_bound if procs_cpu_bound is None else procs_cpu_bound
         use_threads: bool = True if procs_cpu_bound > 1 else False
-        fs = s3.get_fs(session_primitives=self._session.primitives)
+        fs: S3FileSystem = s3.get_fs(session_primitives=self._session.primitives)
         fs = pa.filesystem._ensure_filesystem(fs)
         return pq.read_table(source=path, columns=columns, filters=filters,
                              filesystem=fs).to_pandas(use_threads=use_threads)
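And a similar hedged sketch for read_parquet with the new pd.DataFrame return annotation (the path, columns, and pyarrow-style filters below are hypothetical):

import awswrangler

session = awswrangler.Session()

# procs_cpu_bound > 1 only switches on use_threads for the Arrow-to-pandas conversion here.
df = session.pandas.read_parquet(path="s3://my-bucket/my-dataset/",
                                 columns=["id", "value"],
                                 filters=[("year", "=", "2019")],
                                 procs_cpu_bound=4)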