@@ -71,7 +71,7 @@ def to_redshift(
         :param dataframe: Pandas Dataframe
         :param path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/)
-        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param connection: Glue connection name (str) OR a PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
         :param schema: The Redshift Schema for the table
         :param table: The name of the desired Redshift table
         :param iam_role: AWS IAM role with the related permissions
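
A minimal usage sketch of the widened connection parameter follows; the session construction and every concrete name in it (bucket, table, role ARN, Glue connection name) are placeholders for illustration, not part of this change:

# Hypothetical usage sketch -- bucket, table, ARN and connection name are placeholders.
session = awswrangler.Session(spark_session=spark)

# New str form: resolve an existing Glue connection by name.
session.spark.to_redshift(dataframe=df,
                          path="s3://my-bucket/tmp/",
                          connection="my-glue-redshift-connection",
                          schema="public",
                          table="my_table",
                          iam_role="arn:aws:iam::111111111111:role/my-redshift-copy-role",
                          mode="append")

# Existing form: pass a PEP 249 connection object directly, e.g. one produced by
# Redshift.generate_connection() (its arguments are omitted here).
session.spark.to_redshift(dataframe=df,
                          path="s3://my-bucket/tmp/",
                          connection=conn,
                          schema="public",
                          table="my_table",
                          iam_role="arn:aws:iam::111111111111:role/my-redshift-copy-role",
                          mode="append")

Either form ends up on the same COPY-based load path; the string form simply resolves the Glue connection internally and closes it once the load finishes or fails.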
@@ -93,68 +93,83 @@ def to_redshift(
         dataframe.cache()
         num_rows: int = dataframe.count()
         logger.info(f"Number of rows: {num_rows}")
-        num_partitions: int
-        if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
-            num_partitions = 1
-        else:
-            num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
-            logger.debug(f"Number of slices on Redshift: {num_slices}")
-            num_partitions = num_slices
-            while num_partitions < min_num_partitions:
-                num_partitions += num_slices
-        logger.debug(f"Number of partitions calculated: {num_partitions}")
-        spark.conf.set("spark.sql.execution.arrow.enabled", "true")
-        session_primitives = self._session.primitives
-        par_col_name: str = "aws_data_wrangler_internal_partition_id"

-        @pandas_udf(returnType="objects_paths string", functionType=PandasUDFType.GROUPED_MAP)
-        def write(pandas_dataframe: pd.DataFrame) -> pd.DataFrame:
-            # Exporting ARROW_PRE_0_15_IPC_FORMAT environment variable for
-            # a temporary workaround while waiting for Apache Arrow updates
-            # https://stackoverflow.com/questions/58273063/pandasudf-and-pyarrow-0-15-0
-            os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
+        generated_conn: bool = False
+        if type(connection) == str:
+            logger.debug("Glue connection (str) provided.")
+            connection = self._session.glue.get_connection(name=connection)
+            generated_conn = True

-            del pandas_dataframe[par_col_name]
-            paths: List[str] = session_primitives.session.pandas.to_parquet(dataframe=pandas_dataframe,
-                                                                            path=path,
-                                                                            preserve_index=False,
-                                                                            mode="append",
-                                                                            procs_cpu_bound=1,
-                                                                            procs_io_bound=1,
-                                                                            cast_columns=casts)
-            return pd.DataFrame.from_dict({"objects_paths": paths})
+        try:
+            num_partitions: int
+            if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
+                num_partitions = 1
+            else:
+                num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
+                logger.debug(f"Number of slices on Redshift: {num_slices}")
+                num_partitions = num_slices
+                while num_partitions < min_num_partitions:
+                    num_partitions += num_slices
+            logger.debug(f"Number of partitions calculated: {num_partitions}")
+            spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+            session_primitives = self._session.primitives
+            par_col_name: str = "aws_data_wrangler_internal_partition_id"

-        df_objects_paths: DataFrame = dataframe.repartition(numPartitions=num_partitions)  # type: ignore
-        df_objects_paths: DataFrame = df_objects_paths.withColumn(par_col_name, spark_partition_id())  # type: ignore
-        df_objects_paths: DataFrame = df_objects_paths.groupby(par_col_name).apply(write)  # type: ignore
+            @pandas_udf(returnType="objects_paths string", functionType=PandasUDFType.GROUPED_MAP)
+            def write(pandas_dataframe: pd.DataFrame) -> pd.DataFrame:
+                # Exporting ARROW_PRE_0_15_IPC_FORMAT environment variable for
+                # a temporary workaround while waiting for Apache Arrow updates
+                # https://stackoverflow.com/questions/58273063/pandasudf-and-pyarrow-0-15-0
+                os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"

-        objects_paths: List[str] = list(df_objects_paths.toPandas()["objects_paths"])
-        dataframe.unpersist()
-        num_files_returned: int = len(objects_paths)
-        if num_files_returned != num_partitions:
-            raise MissingBatchDetected(f"{num_files_returned} files returned. {num_partitions} expected.")
-        logger.debug(f"List of objects returned: {objects_paths}")
-        logger.debug(f"Number of objects returned from UDF: {num_files_returned}")
-        manifest_path: str = f"{path}manifest.json"
-        self._session.redshift.write_load_manifest(manifest_path=manifest_path,
-                                                   objects_paths=objects_paths,
-                                                   procs_io_bound=self._procs_io_bound)
-        self._session.redshift.load_table(dataframe=dataframe,
-                                          dataframe_type="spark",
-                                          manifest_path=manifest_path,
-                                          schema_name=schema,
-                                          table_name=table,
-                                          redshift_conn=connection,
-                                          preserve_index=False,
-                                          num_files=num_partitions,
-                                          iam_role=iam_role,
-                                          diststyle=diststyle,
-                                          distkey=distkey,
-                                          sortstyle=sortstyle,
-                                          sortkey=sortkey,
-                                          mode=mode,
-                                          cast_columns=casts)
-        self._session.s3.delete_objects(path=path, procs_io_bound=self._procs_io_bound)
+                del pandas_dataframe[par_col_name]
+                paths: List[str] = session_primitives.session.pandas.to_parquet(dataframe=pandas_dataframe,
+                                                                                path=path,
+                                                                                preserve_index=False,
+                                                                                mode="append",
+                                                                                procs_cpu_bound=1,
+                                                                                procs_io_bound=1,
+                                                                                cast_columns=casts)
+                return pd.DataFrame.from_dict({"objects_paths": paths})
+
+            df_objects_paths: DataFrame = dataframe.repartition(numPartitions=num_partitions)  # type: ignore
+            df_objects_paths = df_objects_paths.withColumn(par_col_name, spark_partition_id())  # type: ignore
+            df_objects_paths = df_objects_paths.groupby(par_col_name).apply(write)  # type: ignore
+
+            objects_paths: List[str] = list(df_objects_paths.toPandas()["objects_paths"])
+            dataframe.unpersist()
+            num_files_returned: int = len(objects_paths)
+            if num_files_returned != num_partitions:
+                raise MissingBatchDetected(f"{num_files_returned} files returned. {num_partitions} expected.")
+            logger.debug(f"List of objects returned: {objects_paths}")
+            logger.debug(f"Number of objects returned from UDF: {num_files_returned}")
+            manifest_path: str = f"{path}manifest.json"
+            self._session.redshift.write_load_manifest(manifest_path=manifest_path,
+                                                       objects_paths=objects_paths,
+                                                       procs_io_bound=self._procs_io_bound)
+            self._session.redshift.load_table(dataframe=dataframe,
+                                              dataframe_type="spark",
+                                              manifest_path=manifest_path,
+                                              schema_name=schema,
+                                              table_name=table,
+                                              redshift_conn=connection,
+                                              preserve_index=False,
+                                              num_files=num_partitions,
+                                              iam_role=iam_role,
+                                              diststyle=diststyle,
+                                              distkey=distkey,
+                                              sortstyle=sortstyle,
+                                              sortkey=sortkey,
+                                              mode=mode,
+                                              cast_columns=casts)
+            self._session.s3.delete_objects(path=path, procs_io_bound=self._procs_io_bound)
+        except Exception as ex:
+            connection.rollback()
+            if generated_conn is True:
+                connection.close()
+            raise ex
+        if generated_conn is True:
+            connection.close()

     def create_glue_table(self,
                           database,
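
The try/except added above implements a borrow-or-create policy: on failure the connection is rolled back, and connections that to_redshift resolved itself from a Glue connection name are also closed, while caller-supplied connections are left open. A generic, standalone sketch of that pattern (the helper names below are illustrative, not awswrangler API):

# Generic sketch of the connection-lifecycle pattern used in to_redshift() above.
# "open_by_name" and "work" are hypothetical callables, used only for illustration.
def run_with_connection(connection, open_by_name, work):
    generated_conn = False
    if isinstance(connection, str):      # a connection *name* was passed, not a live connection
        connection = open_by_name(connection)
        generated_conn = True
    try:
        work(connection)                 # e.g. COPY the staged files into Redshift
    except Exception:
        connection.rollback()            # undo the partial load
        if generated_conn:
            connection.close()           # close only connections we opened ourselves
        raise
    if generated_conn:
        connection.close()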