Skip to content

Commit 0398b84

Browse files
committed
add S3.copy_listed_objects feature.
1 parent c2eed6e commit 0398b84

File tree

2 files changed

+142
-4
lines changed

2 files changed

+142
-4
lines changed

awswrangler/s3.py

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def parse_path(path):
5757
path += "/"
5858
return bucket, path
5959

60+
@staticmethod
61+
def parse_object_path(path):
62+
return path.replace("s3://", "").split("/", 1)
63+
6064
def delete_objects(self, path):
6165
bucket, path = self.parse_path(path=path)
6266
client = self._session.boto3_session.client(
@@ -127,11 +131,13 @@ def delete_listed_objects(self, objects_paths, procs_io_bound=None):
127131
proc.daemon = False
128132
proc.start()
129133
procs.append(proc)
134+
for proc in procs:
135+
proc.join()
130136
else:
131-
self.delete_objects_batch(self._session.primitives, bucket,
132-
batch)
133-
for proc in procs:
134-
proc.join()
137+
self.delete_objects_batch(
138+
session_primitives=self._session.primitives,
139+
bucket=bucket,
140+
batch=batch)
135141

136142
def delete_not_listed_objects(self, objects_paths, procs_io_bound=None):
137143
if not procs_io_bound:
@@ -274,3 +280,75 @@ def get_objects_sizes(self, objects_paths, procs_io_bound=None):
274280
LOGGER.debug(f"Closing proc number: {i}")
275281
receive_pipes[i].close()
276282
return objects_sizes
283+
284+
def copy_listed_objects(self,
                        objects_paths,
                        source_path,
                        target_path,
                        mode="append",
                        procs_io_bound=None):
    """
    Copy a list of S3 objects from source_path to target_path, preserving
    each object's path relative to source_path.

    :param objects_paths: List of full S3 object paths, all under source_path
    :param source_path: S3 prefix the objects currently live under
    :param target_path: S3 prefix to copy the objects to
    :param mode: "append" (default: just copy), "overwrite" (delete
                 everything under target_path first) or
                 "overwrite_partitions" (delete only the touched partitions)
    :param procs_io_bound: Number of worker processes for the copies
                           (default: session.procs_io_bound)
    :return: None
    :raises ValueError: If mode is not one of the three supported values.
    """
    if not procs_io_bound:
        procs_io_bound = self._session.procs_io_bound
    LOGGER.debug(f"procs_io_bound: {procs_io_bound}")
    LOGGER.debug(f"len(objects_paths): {len(objects_paths)}")
    # Fail fast instead of silently treating an unknown mode as "append".
    if mode not in ("append", "overwrite", "overwrite_partitions"):
        raise ValueError(f"Invalid mode: {mode}")
    # endswith() instead of path[-1] so an empty string can't raise IndexError.
    if source_path.endswith("/"):
        source_path = source_path[:-1]
    if target_path.endswith("/"):
        target_path = target_path[:-1]

    if mode == "overwrite":
        LOGGER.debug(f"Deleting to overwrite: {target_path}")
        self._session.s3.delete_objects(path=target_path)
    elif mode == "overwrite_partitions":
        # Derive the set of partition prefixes touched by this copy and
        # delete only those prefixes under the target.
        objects_wo_prefix = [
            o.replace(f"{source_path}/", "", 1) for o in objects_paths
        ]
        objects_wo_filename = [
            f"{o.rpartition('/')[0]}/" for o in objects_wo_prefix
        ]
        partitions_paths = list(set(objects_wo_filename))
        target_partitions_paths = [
            f"{target_path}/{p}" for p in partitions_paths
        ]
        for path in target_partitions_paths:
            LOGGER.debug(f"Deleting to overwrite_partitions: {path}")
            self._session.s3.delete_objects(path=path)

    # Build (source_object, target_object) pairs, remapping each object's
    # relative path onto the target prefix. count=1 so a key containing the
    # source prefix again is not mangled.
    batch = []
    for obj in objects_paths:
        object_wo_prefix = obj.replace(f"{source_path}/", "", 1)
        target_object = f"{target_path}/{object_wo_prefix}"
        batch.append((obj, target_object))

    if procs_io_bound > 1:
        bounders = calculate_bounders(len(objects_paths), procs_io_bound)
        LOGGER.debug(f"bounders: {bounders}")
        procs = []
        for bounder in bounders:
            proc = mp.Process(
                target=self.copy_objects_batch,
                args=(
                    self._session.primitives,
                    batch[bounder[0]:bounder[1]],
                ),
            )
            proc.daemon = False
            proc.start()
            procs.append(proc)
        # Wait for every worker; the original delete_listed_objects had this
        # join misplaced inside the else branch before this commit fixed it.
        for proc in procs:
            proc.join()
    else:
        self.copy_objects_batch(
            session_primitives=self._session.primitives, batch=batch)
343+
344+
@staticmethod
def copy_objects_batch(session_primitives, batch):
    """
    Copy every (source, target) pair of S3 object paths in batch.

    Used as a multiprocessing worker target by copy_listed_objects, so it
    reconstructs the boto3 resource from the session primitives instead of
    sharing a client across processes.

    :param session_primitives: Session primitives to rebuild the session from
    :param batch: List of (source_object_path, target_object_path) tuples
    :return: None
    """
    session = session_primitives.session
    s3_resource = session.boto3_session.resource(
        service_name="s3", config=session.botocore_config)
    LOGGER.debug(f"len(batch): {len(batch)}")
    client = s3_resource.meta.client
    for src, dst in batch:
        src_bucket, src_key = S3.parse_object_path(path=src)
        dst_bucket, dst_key = S3.parse_object_path(path=dst)
        client.copy({"Bucket": src_bucket, "Key": src_key}, dst_bucket,
                    dst_key)

testing/test_awswrangler/test_s3.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
import boto3
7+
import pandas
78

89
from awswrangler import Session
910

@@ -81,6 +82,15 @@ def bucket(session, cloudformation_outputs):
8182
session.s3.delete_objects(path=f"s3://{bucket}/")
8283

8384

85+
@pytest.fixture(scope="module")
def database(cloudformation_outputs):
    """Yield the Glue database name exported by the test CloudFormation stack."""
    # Guard clause: missing output means the SAM stack was never deployed.
    if "GlueDatabaseName" not in cloudformation_outputs:
        raise Exception("You must deploy the test infrastructure using SAM!")
    yield cloudformation_outputs["GlueDatabaseName"]
92+
93+
8494
@pytest.mark.parametrize("objects_num", [1, 10, 1001, 2001, 3001])
8595
def test_delete_objects(session, bucket, objects_num):
8696
print("Starting writes...")
@@ -147,3 +157,53 @@ def test_get_objects_sizes(session, bucket, objects_num):
147157
session.s3.delete_objects(path=path)
148158
for _, object_size in objects_sizes.items():
149159
assert object_size == 10
160+
161+
162+
@pytest.mark.parametrize("mode, procs_io_bound", [
    ("append", 1),
    ("overwrite", 1),
    ("overwrite_partitions", 1),
    ("append", 8),
    ("overwrite", 8),
    ("overwrite_partitions", 8),
])
def test_copy_listed_objects(session, bucket, database, mode, procs_io_bound):
    """
    Write the same partitioned dataset to two prefixes, copy the second
    onto the first with copy_listed_objects, then check the resulting row
    count through Athena for each copy mode (serial and parallel).
    """
    path0 = f"s3://{bucket}/test_move_objects_0/"
    path1 = f"s3://{bucket}/test_move_objects_1/"
    print("Starting deletes...")
    session.s3.delete_objects(path=path0)
    session.s3.delete_objects(path=path1)
    dataframe = pandas.read_csv("data_samples/micro.csv")
    print("Starting writing path0...")
    session.pandas.to_parquet(
        dataframe=dataframe,
        database=database,
        path=path0,
        preserve_index=False,
        mode="overwrite",
        partition_cols=["name", "date"],
    )
    # Fixed copy-paste bug: this message wrongly said "path0".
    print("Starting writing path1...")
    objects_paths = session.pandas.to_parquet(
        dataframe=dataframe,
        path=path1,
        preserve_index=False,
        mode="overwrite",
        partition_cols=["name", "date"],
    )
    print("Starting move...")
    session.s3.copy_listed_objects(
        objects_paths=objects_paths,
        source_path=path1,
        target_path=path0,
        mode=mode,
        procs_io_bound=procs_io_bound,
    )
    print("Asserting...")
    # NOTE(review): presumably waiting for S3/Athena listing consistency
    # before querying — confirm whether this is still needed.
    sleep(1)
    dataframe2 = session.pandas.read_sql_athena(
        sql="select * from test_move_objects_0", database=database)
    if mode == "append":
        # Copying onto an identical dataset without deleting doubles rows.
        assert 2 * len(dataframe.index) == len(dataframe2.index)
    else:
        # overwrite / overwrite_partitions leave exactly one copy of the data.
        assert len(dataframe.index) == len(dataframe2.index)

0 commit comments

Comments
 (0)