import logging

import pg8000  # type: ignore
+ import pyarrow as pa  # type: ignore

from awswrangler import data_types
- from awswrangler.exceptions import (
-     RedshiftLoadError,
-     InvalidDataframeType,
-     InvalidRedshiftDiststyle,
-     InvalidRedshiftDistkey,
-     InvalidRedshiftSortstyle,
-     InvalidRedshiftSortkey,
- )
+ from awswrangler.exceptions import (RedshiftLoadError, InvalidDataframeType, InvalidRedshiftDiststyle,
+                                     InvalidRedshiftDistkey, InvalidRedshiftSortstyle, InvalidRedshiftSortkey,
+                                     InvalidRedshiftPrimaryKeys)

logger = logging.getLogger(__name__)

@@ -165,6 +161,7 @@ def load_table(dataframe,
                   distkey=None,
                   sortstyle="COMPOUND",
                   sortkey=None,
+                    primary_keys: Optional[List[str]] = None,
                   mode="append",
                   preserve_index=False,
                   cast_columns=None):
@@ -184,11 +181,14 @@ def load_table(dataframe,
        :param distkey: Specifies a column name or positional number for the distribution key
        :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED" (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
        :param sortkey: List of columns to be sorted
-         :param mode: append or overwrite
+         :param primary_keys: Primary key column names (declared on table creation and used to build the upsert merge)
+         :param mode: append, overwrite or upsert
        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
        :param cast_columns: Dictionary of columns names and Redshift types to be casted. (E.g. {"col name": "INT", "col2 name": "FLOAT"})
        :return: None
        """
+         final_table_name: Optional[str] = None
+         temp_table_name: Optional[str] = None
        cursor = redshift_conn.cursor()
        if mode == "overwrite":
            Redshift._create_table(cursor=cursor,
@@ -200,13 +200,27 @@ def load_table(dataframe,
                                   distkey=distkey,
                                   sortstyle=sortstyle,
                                   sortkey=sortkey,
+                                    primary_keys=primary_keys,
                                   preserve_index=preserve_index,
                                   cast_columns=cast_columns)
+             table_name = f"{schema_name}.{table_name}"
+         elif mode == "upsert":
+             guid: str = pa.compat.guid()
+             temp_table_name = f"temp_redshift_{guid}"
+             final_table_name = table_name
+             table_name = temp_table_name
+             sql: str = f"CREATE TEMPORARY TABLE {temp_table_name} (LIKE {schema_name}.{final_table_name})"
+             logger.debug(sql)
+             cursor.execute(sql)
+         else:
+             table_name = f"{schema_name}.{table_name}"
+
        sql = ("-- AWS DATA WRANGLER\n"
-                f"COPY {schema_name}.{table_name} FROM '{manifest_path}'\n"
+                f"COPY {table_name} FROM '{manifest_path}'\n"
               f"IAM_ROLE '{iam_role}'\n"
               "MANIFEST\n"
               "FORMAT AS PARQUET")
+         logger.debug(sql)
        cursor.execute(sql)
        cursor.execute("-- AWS DATA WRANGLER\nSELECT pg_last_copy_id() AS query_id")
        query_id = cursor.fetchall()[0][0]
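For reference while reading the new upsert branch, here is roughly what it stages before the COPY runs; every name, path, and ARN below is a made-up placeholder, not a value from this change.

```python
# Placeholders; the real code derives temp_table_name from pa.compat.guid().
schema_name = "analytics"
final_table_name = "orders"
temp_table_name = "temp_redshift_3f9c1a2b"
manifest_path = "s3://my-bucket/staging/manifest.json"
iam_role = "arn:aws:iam::123456789012:role/my-redshift-role"

# 1) Stage into a temporary table that clones the target's definition.
staging_sql = f"CREATE TEMPORARY TABLE {temp_table_name} (LIKE {schema_name}.{final_table_name})"

# 2) COPY the Parquet manifest into the temporary table instead of the target.
copy_sql = ("-- AWS DATA WRANGLER\n"
            f"COPY {temp_table_name} FROM '{manifest_path}'\n"
            f"IAM_ROLE '{iam_role}'\n"
            "MANIFEST\n"
            "FORMAT AS PARQUET")
print(staging_sql)
print(copy_sql)
```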
@@ -219,6 +233,23 @@ def load_table(dataframe,
            cursor.close()
            raise RedshiftLoadError(
                f"Redshift load rollbacked. {num_files_loaded} files counted. {num_files} expected.")
+
+         if (mode == "upsert") and (final_table_name is not None):
+             if not primary_keys:
+                 primary_keys = Redshift.get_primary_keys(connection=redshift_conn,
+                                                          schema=schema_name,
+                                                          table=final_table_name)
+             if not primary_keys:
+                 raise InvalidRedshiftPrimaryKeys()
+             equals_clause = f"{final_table_name}.%s = {temp_table_name}.%s"
+             join_clause = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])
+             sql = f"DELETE FROM {schema_name}.{final_table_name} USING {temp_table_name} WHERE {join_clause}"
+             logger.debug(sql)
+             cursor.execute(sql)
+             sql = f"INSERT INTO {schema_name}.{final_table_name} SELECT * FROM {temp_table_name}"
+             logger.debug(sql)
+             cursor.execute(sql)
+
        redshift_conn.commit()
        cursor.close()

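For the merge step added above, this is the SQL it ends up issuing for a made-up pair of key columns (same placeholder names as the staging sketch earlier):

```python
schema_name = "analytics"
final_table_name = "orders"
temp_table_name = "temp_redshift_3f9c1a2b"
primary_keys = ["order_id", "order_date"]

# Build the join on the primary keys, exactly as in the diff above.
equals_clause = f"{final_table_name}.%s = {temp_table_name}.%s"
join_clause = " AND ".join([equals_clause % (pk, pk) for pk in primary_keys])

# Delete the rows that are about to be replaced, then insert the staged ones.
delete_sql = f"DELETE FROM {schema_name}.{final_table_name} USING {temp_table_name} WHERE {join_clause}"
insert_sql = f"INSERT INTO {schema_name}.{final_table_name} SELECT * FROM {temp_table_name}"
print(delete_sql)
# DELETE FROM analytics.orders USING temp_redshift_3f9c1a2b
#   WHERE orders.order_id = temp_redshift_3f9c1a2b.order_id
#   AND orders.order_date = temp_redshift_3f9c1a2b.order_date
print(insert_sql)
# INSERT INTO analytics.orders SELECT * FROM temp_redshift_3f9c1a2b
```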
@@ -232,6 +263,7 @@ def _create_table(cursor,
                      distkey=None,
                      sortstyle="COMPOUND",
                      sortkey=None,
+                       primary_keys: List[str] = None,
                      preserve_index=False,
                      cast_columns=None):
        """
@@ -246,6 +278,7 @@ def _create_table(cursor,
        :param distkey: Specifies a column name or positional number for the distribution key
        :param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED" (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
        :param sortkey: List of columns to be sorted
+         :param primary_keys: Primary key column names (added to the table as a PRIMARY KEY constraint)
        :param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
        :param cast_columns: Dictionary of columns names and Redshift types to be casted. (E.g. {"col name": "INT", "col2 name": "FLOAT"})
        :return: None
@@ -273,22 +306,43 @@ def _create_table(cursor,
                                      distkey=distkey,
                                      sortstyle=sortstyle,
                                      sortkey=sortkey)
-         cols_str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
-         distkey_str = ""
+         cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
+         primary_keys_str: str = ""
+         if primary_keys:
+             primary_keys_str = f",\nPRIMARY KEY ({', '.join(primary_keys)})"
+         distkey_str: str = ""
        if distkey and diststyle == "KEY":
            distkey_str = f"\nDISTKEY({distkey})"
-         sortkey_str = ""
+         sortkey_str: str = ""
        if sortkey:
            sortkey_str = f"\n{sortstyle} SORTKEY({','.join(sortkey)})"
        sql = (f"-- AWS DATA WRANGLER\n"
               f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n"
               f"{cols_str}"
+                f"{primary_keys_str}"
               f")\nDISTSTYLE {diststyle}"
               f"{distkey_str}"
               f"{sortkey_str}")
        logger.debug(f"Create table query:\n{sql}")
        cursor.execute(sql)

+     @staticmethod
+     def get_primary_keys(connection, schema, table):
+         """
+         Get the primary key columns of a Redshift table
+         :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+         :param schema: Schema name
+         :param table: Redshift table name
+         :return: Primary key column names (List[str])
+         """
+         cursor = connection.cursor()
+         cursor.execute(f"SELECT indexdef FROM pg_indexes WHERE schemaname = '{schema}' AND tablename = '{table}'")
+         result = cursor.fetchall()[0][0]
+         rfields = result.split('(')[1].strip(')').split(',')
+         fields = [field.strip().strip('"') for field in rfields]
+         cursor.close()
+         return fields
+
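As a sanity check on the indexdef parsing just added, here is what those string operations produce for one plausible, made-up pg_indexes row; the exact indexdef text Redshift returns may differ.

```python
# Hypothetical indexdef value; only the parenthesized column list matters here.
result = 'CREATE UNIQUE INDEX orders_pkey ON my_schema.orders USING btree (order_id, "order date")'

rfields = result.split('(')[1].strip(')').split(',')      # ['order_id', ' "order date"']
fields = [field.strip().strip('"') for field in rfields]  # strip spaces and surrounding quotes
print(fields)  # ['order_id', 'order date']
```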
    @staticmethod
    def _validate_parameters(schema, diststyle, distkey, sortstyle, sortkey):
        """
@@ -347,8 +401,8 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_c
            raise InvalidDataframeType(dataframe_type)
        return schema_built

-     @staticmethod
-     def to_parquet(sql: str,
+     def to_parquet(self,
+                    sql: str,
                   path: str,
                   iam_role: str,
                   connection: Any,
@@ -366,8 +420,11 @@ def to_parquet(sql: str,
        path = path if path[-1] == "/" else path + "/"
        cursor: Any = connection.cursor()
        partition_str: str = ""
+         manifest_str: str = ""
        if partition_cols is not None:
            partition_str = f"PARTITION BY ({','.join([x for x in partition_cols])})\n"
+         else:
+             manifest_str = "\nmanifest"
        query: str = f"-- AWS DATA WRANGLER\n" \
                     f"UNLOAD ('{sql}')\n" \
                     f"TO '{path}'\n" \
@@ -376,7 +433,8 @@ def to_parquet(sql: str,
                     f"PARALLEL ON\n" \
                     f"ENCRYPTED \n" \
                     f"{partition_str}" \
-                      f"FORMAT PARQUET;"
+                      f"FORMAT PARQUET" \
+                      f"{manifest_str};"
        logger.debug(f"query:\n{query}")
        cursor.execute(query)
        query = "-- AWS DATA WRANGLER\nSELECT pg_last_query_id() AS query_id"
@@ -391,4 +449,8 @@ def to_parquet(sql: str,
        logger.debug(f"paths: {paths}")
        connection.commit()
        cursor.close()
+         if manifest_str != "":
+             self._session.s3.wait_object_exists(path=f"{path}manifest")
+             for p in paths:
+                 self._session.s3.wait_object_exists(path=p)
        return paths
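A small sketch of what the manifest change amounts to at the call site; the bucket and prefix below are placeholders, and the print stands in for the session's S3 wait helper used above.

```python
# Placeholders only; not values from this change.
path = "s3://my-bucket/unload/"
partition_cols = None

# Mirror of the new manifest_str logic in to_parquet():
manifest_str = "" if partition_cols is not None else "\nmanifest"

# The UNLOAD statement now ends with "FORMAT PARQUET\nmanifest;" instead of
# "FORMAT PARQUET;", so Redshift also writes a manifest object under the prefix.
print(repr(f"FORMAT PARQUET{manifest_str};"))  # 'FORMAT PARQUET\nmanifest;'

# After the statement runs, the method waits for that object before returning:
print(f"{path}manifest")  # s3://my-bucket/unload/manifest
```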