
Commit 71d02c6

Add pandas.to_redshift upsert

1 parent f2de54c, commit 71d02c6

14 files changed: +248 additions, -59 deletions


awswrangler/__init__.py

Lines changed: 1 addition & 7 deletions

@@ -23,13 +23,7 @@ def __init__(self, service):
         self._service = service
 
     def __getattr__(self, name):
-        return getattr(
-            getattr(
-                DynamicInstantiate.__default_session,
-                self._service
-            ),
-            name
-        )
+        return getattr(getattr(DynamicInstantiate.__default_session, self._service), name)
 
 
 if importlib.util.find_spec("pyspark"):  # type: ignore

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions

@@ -58,6 +58,10 @@ class InvalidRedshiftSortkey(Exception):
     pass
 
 
+class InvalidRedshiftPrimaryKeys(Exception):
+    pass
+
+
 class EmptyDataframe(Exception):
     pass

awswrangler/pandas.py

Lines changed: 18 additions & 6 deletions

@@ -1096,6 +1096,7 @@ def to_redshift(
                     distkey: Optional[str] = None,
                     sortstyle: str = "COMPOUND",
                     sortkey: Optional[str] = None,
+                    primary_keys: Optional[str] = None,
                     preserve_index: bool = False,
                     mode: str = "append",
                     cast_columns: Optional[Dict[str, str]] = None,
@@ -1113,6 +1114,7 @@ def to_redshift(
         :param distkey: Specifies a column name or positional number for the distribution key
         :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED" (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
         :param sortkey: List of columns to be sorted
+        :param primary_keys: Primary keys
         :param preserve_index: Should we preserve the Dataframe index?
         :param mode: append, overwrite or upsert
         :param cast_columns: Dictionary of columns names and Redshift types to be casted. (E.g. {"col name": "SMALLINT", "col2 name": "FLOAT4"})
@@ -1159,6 +1161,7 @@ def to_redshift(
             distkey=distkey,
             sortstyle=sortstyle,
             sortkey=sortkey,
+            primary_keys=primary_keys,
             mode=mode,
             cast_columns=cast_columns,
         )
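
From the caller's side, the new mode is exercised roughly as below — a minimal sketch assuming the Session-based call style of the library; the pg8000 connection details, bucket, IAM role, and table names are placeholders, not values from this commit:

    import pandas as pd
    import pg8000
    import awswrangler

    session = awswrangler.Session()
    # Any PEP 249 compatible connection works; pg8000 is what the library itself uses.
    con = pg8000.connect(user="awsuser", password="...", port=5439, database="mydb",
                         host="my-cluster.xxxxxxxx.us-east-1.redshift.amazonaws.com")

    df = pd.DataFrame({"id": [1, 2, 3], "value": ["foo", "bar", "baz"]})

    session.pandas.to_redshift(
        dataframe=df,
        path="s3://my-bucket/temp/",        # S3 staging area for the Parquet files to COPY
        schema="public",
        table="my_table",
        connection=con,
        iam_role="arn:aws:iam::123456789012:role/RedshiftCopyRole",
        primary_keys=["id"],                # existing rows matching on "id" are replaced
        mode="upsert",
    )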
@@ -1344,14 +1347,23 @@ def _read_parquet_path(session_primitives: Any,
         :param filters: List of filters to apply, like ``[[('x', '=', 0), ...], ...]``.
         :param procs_cpu_bound: Number of cores used for CPU bound tasks
         """
-        path = path[:-1] if path[-1] == "/" else path
+        session = session_primitives.session
+        is_file: bool = session.s3.does_object_exists(path=path)
+        if is_file is False:
+            path = path[:-1] if path[-1] == "/" else path
         procs_cpu_bound = procs_cpu_bound if procs_cpu_bound is not None else session_primitives.procs_cpu_bound if session_primitives.procs_cpu_bound is not None else 1
         use_threads: bool = True if procs_cpu_bound > 1 else False
-        fs: S3FileSystem = s3.get_fs(session_primitives=session_primitives)
-        fs.invalidate_cache()
-        fs = pa.filesystem._ensure_filesystem(fs)
-        logger.debug(f"Reading Parquet table: {path}")
-        table = pq.read_table(source=path, columns=columns, filters=filters, filesystem=fs, use_threads=use_threads)
+        logger.debug(f"Reading Parquet: {path}")
+        if is_file is True:
+            client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
+            bucket, key = path.replace("s3://", "").split("/", 1)
+            obj = client_s3.get_object(Bucket=bucket, Key=key)
+            table = pq.ParquetFile(source=BytesIO(obj["Body"].read())).read(columns=columns, use_threads=use_threads)
+        else:
+            fs: S3FileSystem = s3.get_fs(session_primitives=session_primitives)
+            fs = pa.filesystem._ensure_filesystem(fs)
+            fs.invalidate_cache()
+            table = pq.read_table(source=path, columns=columns, filters=filters, filesystem=fs, use_threads=use_threads)
         # Check if we lose some integer during the conversion (Happens when has some null value)
         integers = [field.name for field in table.schema if str(field.type).startswith("int")]
         logger.debug(f"Converting to Pandas: {path}")

awswrangler/redshift.py

Lines changed: 78 additions & 16 deletions

@@ -3,16 +3,12 @@
 import logging
 
 import pg8000  # type: ignore
+import pyarrow as pa  # type: ignore
 
 from awswrangler import data_types
-from awswrangler.exceptions import (
-    RedshiftLoadError,
-    InvalidDataframeType,
-    InvalidRedshiftDiststyle,
-    InvalidRedshiftDistkey,
-    InvalidRedshiftSortstyle,
-    InvalidRedshiftSortkey,
-)
+from awswrangler.exceptions import (RedshiftLoadError, InvalidDataframeType, InvalidRedshiftDiststyle,
+                                    InvalidRedshiftDistkey, InvalidRedshiftSortstyle, InvalidRedshiftSortkey,
+                                    InvalidRedshiftPrimaryKeys)
 
 logger = logging.getLogger(__name__)
 
@@ -165,6 +161,7 @@ def load_table(dataframe,
                    distkey=None,
                    sortstyle="COMPOUND",
                    sortkey=None,
+                   primary_keys: Optional[List[str]] = None,
                    mode="append",
                    preserve_index=False,
                    cast_columns=None):
@@ -184,11 +181,14 @@ def load_table(dataframe,
         :param distkey: Specifies a column name or positional number for the distribution key
         :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED" (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
         :param sortkey: List of columns to be sorted
-        :param mode: append or overwrite
+        :param primary_keys: Primary keys
+        :param mode: append, overwrite or upsert
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
        :param cast_columns: Dictionary of columns names and Redshift types to be casted. (E.g. {"col name": "INT", "col2 name": "FLOAT"})
         :return: None
         """
+        final_table_name: Optional[str] = None
+        temp_table_name: Optional[str] = None
         cursor = redshift_conn.cursor()
         if mode == "overwrite":
             Redshift._create_table(cursor=cursor,
@@ -200,13 +200,27 @@ def load_table(dataframe,
                                    distkey=distkey,
                                    sortstyle=sortstyle,
                                    sortkey=sortkey,
+                                   primary_keys=primary_keys,
                                    preserve_index=preserve_index,
                                    cast_columns=cast_columns)
+            table_name = f"{schema_name}.{table_name}"
+        elif mode == "upsert":
+            guid: str = pa.compat.guid()
+            temp_table_name = f"temp_redshift_{guid}"
+            final_table_name = table_name
+            table_name = temp_table_name
+            sql: str = f"CREATE TEMPORARY TABLE {temp_table_name} (LIKE {schema_name}.{final_table_name})"
+            logger.debug(sql)
+            cursor.execute(sql)
+        else:
+            table_name = f"{schema_name}.{table_name}"
+
         sql = ("-- AWS DATA WRANGLER\n"
-               f"COPY {schema_name}.{table_name} FROM '{manifest_path}'\n"
+               f"COPY {table_name} FROM '{manifest_path}'\n"
               f"IAM_ROLE '{iam_role}'\n"
               "MANIFEST\n"
               "FORMAT AS PARQUET")
+        logger.debug(sql)
         cursor.execute(sql)
         cursor.execute("-- AWS DATA WRANGLER\n SELECT pg_last_copy_id() AS query_id")
         query_id = cursor.fetchall()[0][0]
@@ -219,6 +233,23 @@ def load_table(dataframe,
             cursor.close()
             raise RedshiftLoadError(
                 f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected.")
+
+        if (mode == "upsert") and (final_table_name is not None):
+            if not primary_keys:
+                primary_keys = Redshift.get_primary_keys(connection=redshift_conn,
+                                                         schema=schema_name,
+                                                         table=final_table_name)
+            if not primary_keys:
+                raise InvalidRedshiftPrimaryKeys()
+            equals_clause = f"{final_table_name}.%s = {temp_table_name}.%s"
+            join_clause = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])
+            sql = f"DELETE FROM {schema_name}.{final_table_name} USING {temp_table_name} WHERE {join_clause}"
+            logger.debug(sql)
+            cursor.execute(sql)
+            sql = f"INSERT INTO {schema_name}.{final_table_name} SELECT * FROM {temp_table_name}"
+            logger.debug(sql)
+            cursor.execute(sql)
+
         redshift_conn.commit()
         cursor.close()
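
Taken together, the upsert path creates a session-scoped temporary table shaped like the target, COPYs the manifest into it, deletes the target rows that match on the primary keys, then inserts everything from the staging table. The snippet below is an illustrative reconstruction of how the DELETE/INSERT pair is composed for a hypothetical public.my_table staged into temp_redshift_abc123 with primary_keys=["id", "region"]; the table and key names are placeholders:

    final_table_name = "my_table"
    temp_table_name = "temp_redshift_abc123"   # suffix is a generated GUID in the real code
    schema_name = "public"
    primary_keys = ["id", "region"]

    equals_clause = f"{final_table_name}.%s = {temp_table_name}.%s"
    join_clause = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])
    delete_sql = f"DELETE FROM {schema_name}.{final_table_name} USING {temp_table_name} WHERE {join_clause}"
    insert_sql = f"INSERT INTO {schema_name}.{final_table_name} SELECT * FROM {temp_table_name}"

    print(delete_sql)
    # DELETE FROM public.my_table USING temp_redshift_abc123
    #   WHERE my_table.id = temp_redshift_abc123.id AND my_table.region = temp_redshift_abc123.region
    print(insert_sql)
    # INSERT INTO public.my_table SELECT * FROM temp_redshift_abc123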

@@ -232,6 +263,7 @@ def _create_table(cursor,
                       distkey=None,
                       sortstyle="COMPOUND",
                       sortkey=None,
+                      primary_keys: List[str] = None,
                       preserve_index=False,
                       cast_columns=None):
         """
@@ -246,6 +278,7 @@ def _create_table(cursor,
         :param distkey: Specifies a column name or positional number for the distribution key
         :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED" (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
         :param sortkey: List of columns to be sorted
+        :param primary_keys: Primary keys
         :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
         :param cast_columns: Dictionary of columns names and Redshift types to be casted. (E.g. {"col name": "INT", "col2 name": "FLOAT"})
         :return: None
@@ -273,22 +306,43 @@ def _create_table(cursor,
                                                         distkey=distkey,
                                                         sortstyle=sortstyle,
                                                         sortkey=sortkey)
-        cols_str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
-        distkey_str = ""
+        cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
+        primary_keys_str: str = ""
+        if primary_keys:
+            primary_keys_str = f",\nPRIMARY KEY ({', '.join(primary_keys)})"
+        distkey_str: str = ""
         if distkey and diststyle == "KEY":
             distkey_str = f"\nDISTKEY({distkey})"
-        sortkey_str = ""
+        sortkey_str: str = ""
         if sortkey:
             sortkey_str = f"\n{sortstyle} SORTKEY({','.join(sortkey)})"
         sql = (f"-- AWS DATA WRANGLER\n"
                f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n"
                f"{cols_str}"
+               f"{primary_keys_str}"
                f")\nDISTSTYLE {diststyle}"
                f"{distkey_str}"
                f"{sortkey_str}")
         logger.debug(f"Create table query:\n{sql}")
         cursor.execute(sql)
 
+    @staticmethod
+    def get_primary_keys(connection, schema, table):
+        """
+        Get PKs
+        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param schema: Schema name
+        :param table: Redshift table name
+        :return: PKs list List[str]
+        """
+        cursor = connection.cursor()
+        cursor.execute(f"SELECT indexdef FROM pg_indexes WHERE schemaname = '{schema}' AND tablename = '{table}'")
+        result = cursor.fetchall()[0][0]
+        rfields = result.split('(')[1].strip(')').split(',')
+        fields = [field.strip().strip('"') for field in rfields]
+        cursor.close()
+        return fields
+
     @staticmethod
     def _validate_parameters(schema, diststyle, distkey, sortstyle, sortkey):
         """
@@ -347,8 +401,8 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c
             raise InvalidDataframeType(dataframe_type)
         return schema_built
 
-    @staticmethod
-    def to_parquet(sql: str,
+    def to_parquet(self,
+                   sql: str,
                    path: str,
                    iam_role: str,
                    connection: Any,
@@ -366,8 +420,11 @@ def to_parquet(sql: str,
         path = path if path[-1] == "/" else path + "/"
         cursor: Any = connection.cursor()
         partition_str: str = ""
+        manifest_str: str = ""
         if partition_cols is not None:
             partition_str = f"PARTITION BY ({','.join([x for x in partition_cols])})\n"
+        else:
+            manifest_str = "\nmanifest"
         query: str = f"-- AWS DATA WRANGLER\n" \
                      f"UNLOAD ('{sql}')\n" \
                      f"TO '{path}'\n" \
@@ -376,7 +433,8 @@ def to_parquet(sql: str,
                      f"PARALLEL ON\n" \
                      f"ENCRYPTED \n" \
                      f"{partition_str}" \
-                     f"FORMAT PARQUET;"
+                     f"FORMAT PARQUET" \
+                     f"{manifest_str};"
         logger.debug(f"query:\n{query}")
         cursor.execute(query)
         query = "-- AWS DATA WRANGLER\nSELECT pg_last_query_id() AS query_id"
@@ -391,4 +449,8 @@ def to_parquet(sql: str,
         logger.debug(f"paths: {paths}")
         connection.commit()
         cursor.close()
+        if manifest_str != "":
+            self._session.s3.wait_object_exists(path=f"{path}manifest")
+            for p in paths:
+                self._session.s3.wait_object_exists(path=p)
         return paths
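
Because to_parquet is now an instance method, it can use the session's S3 helper to block until the UNLOAD output (including the manifest written when no partition_cols are given) is actually visible before returning the paths. A hedged usage sketch — the session.redshift accessor, connection details, bucket, and IAM role below are illustrative assumptions, not values from this commit:

    import awswrangler
    import pg8000

    session = awswrangler.Session()
    con = pg8000.connect(user="awsuser", password="...", port=5439, database="mydb",
                         host="my-cluster.xxxxxxxx.us-east-1.redshift.amazonaws.com")

    paths = session.redshift.to_parquet(
        sql="SELECT * FROM public.my_table",
        path="s3://my-bucket/unload/",                                 # placeholder bucket
        iam_role="arn:aws:iam::123456789012:role/RedshiftUnloadRole",  # placeholder role
        connection=con,
    )
    print(paths)  # s3:// paths of the unloaded Parquet files, already visible on S3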

awswrangler/s3.py

Lines changed: 37 additions & 4 deletions

@@ -1,7 +1,8 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 import multiprocessing as mp
 from math import ceil
 import logging
+from time import sleep
 
 from botocore.exceptions import ClientError, HTTPClientError  # type: ignore
 import s3fs  # type: ignore
@@ -21,7 +22,7 @@ def mkdir_if_not_exists(fs, path):
 
 
 def get_fs(session_primitives=None):
-    aws_access_key_id, aws_secret_access_key, profile_name, config, s3_additional_kwargs = None, None, None, None, None
+    aws_access_key_id, aws_secret_access_key, profile_name = None, None, None
    args = {}
 
    if session_primitives is not None:
@@ -42,17 +43,49 @@ def get_fs(session_primitives=None):
         args["key"] = aws_access_key_id,
         args["secret"] = aws_secret_access_key
 
+    args["default_cache_type"] = "none"
+    args["default_fill_cache"] = False
     fs = s3fs.S3FileSystem(**args)
-    fs.invalidate_cache(path=None)
     return fs
 
 
 class S3:
     def __init__(self, session):
         self._session = session
+        self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
+
+    def does_object_exists(self, path: str) -> bool:
+        """
+        Check if object exists on S3
+
+        :param path: S3 path (e.g. s3://...)
+        :return: boolean
+        """
+        bucket: str
+        key: str
+        bucket, key = path.replace("s3://", "").split("/", 1)
+        try:
+            self._client_s3.head_object(Bucket=bucket, Key=key)
+            return True
+        except ClientError as ex:
+            if ex.response["ResponseMetadata"]["HTTPStatusCode"] == 404:
+                return False
+            raise ex
+
+    def wait_object_exists(self, path: str, polling_sleep: float = 0.1) -> None:
+        """
+        Wait object exists on S3
+
+        :param path: S3 path (e.g. s3://...)
+        :param polling_sleep: Milliseconds
+        :return: None
+        """
+        while self.does_object_exists(path=path) is False:
+            sleep(polling_sleep)
 
     @staticmethod
-    def parse_path(path):
+    def parse_path(path: str) -> Tuple[str, str]:
+        bucket: str
         bucket, path = path.replace("s3://", "").split("/", 1)
         if not path:
             path = ""

awswrangler/sagemaker.py

Lines changed: 2 additions & 2 deletions

@@ -23,15 +23,15 @@ def get_job_outputs(self, path: str) -> Any:
         if key.split("/")[-1] != "model.tar.gz":
             key = f"{key}/model.tar.gz"
         body = self._client_s3.get_object(Bucket=bucket, Key=key)["Body"].read()
-        body = tarfile.io.BytesIO(body)
+        body = tarfile.io.BytesIO(body)  # type: ignore
         tar = tarfile.open(fileobj=body)
 
         results = []
         for member in tar.getmembers():
             f = tar.extractfile(member)
             file_type = member.name.split(".")[-1]
 
-            if file_type == "pkl":
+            if (file_type == "pkl") and (f is not None):
                 f = pickle.load(f)
 
             results.append(f)

awswrangler/session.py

Lines changed: 0 additions & 1 deletion

@@ -15,7 +15,6 @@
 from awswrangler.emr import EMR
 from awswrangler.sagemaker import SageMaker
 
-
 PYSPARK_INSTALLED = False
 if importlib.util.find_spec("pyspark"):  # type: ignore
     PYSPARK_INSTALLED = True

docs/source/api/awswrangler.rst

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@ Submodules
    awswrangler.pandas
    awswrangler.redshift
    awswrangler.s3
+   awswrangler.sagemaker
    awswrangler.session
    awswrangler.spark
    awswrangler.utils

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+awswrangler.sagemaker module
+============================
+
+.. automodule:: awswrangler.sagemaker
+   :members:
+   :undoc-members:
+   :show-inheritance:
