Commit 5773b7f

perf: iterate over generators when writing datafiles to reduce memory pressure (apache#2671)
# Rationale for this change

When writing to partitioned tables, there is a large memory spike when the partitions are computed, because we call `.combine_chunks()` on each new partitioned Arrow table and materialize the entire list of partitions before writing data files. This PR switches the partition computation to a generator so that the partitions are not all held in memory at once, reducing the memory overhead of writing to partitioned tables.

## Are these changes tested?

No new tests. The tests that use this method were updated to consume the generator as a list. However, in my personal use case, I use `pa.total_allocated_bytes()` to measure memory allocation before and after the write, and see the following across 5 writes of ~128 MB each:

| Run | Original Impl (Before Write) | Original Impl (After Write) | Iters (Before Write) | Iters (After Write) |
|---|---|---|---|---|
| 1 | 29.31 MB | 151.62 MB | 28.38 MB | 30.40 MB |
| 2 | 27.74 MB | 151.62 MB | 28.85 MB | 30.36 MB |
| 3 | 28.81 MB | 151.62 MB | 28.52 MB | 31.29 MB |
| 4 | 28.71 MB | 151.62 MB | 29.27 MB | 30.64 MB |
| 5 | 28.60 MB | 151.61 MB | 28.29 MB | 31.11 MB |

The overhead scales with the size of the write: with the original implementation, writing a 3 GB Arrow table to a partitioned table requires at least 6 GB of RAM.

## Are there any user-facing changes?

No.
1 parent 8878b2c commit 5773b7f
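
For reference, the before/after numbers in the table above can be collected with pyarrow's allocation counter. Below is a minimal sketch under assumptions: the catalog name, table identifier, and toy dataframe are placeholders (the actual measurement used ~128 MB Arrow tables appended to a partitioned table), and none of this code is part of the commit.

```python
import pyarrow as pa
from pyiceberg.catalog import load_catalog

# Placeholder catalog and table identifier; assumes a partitioned Iceberg table exists.
catalog = load_catalog("default")
tbl = catalog.load_table("db.partitioned_table")

# Toy dataframe; the measurements above were taken with ~128 MB Arrow tables.
df = pa.table({"year": [2020, 2021, 2021], "n_legs": [2, 4, 100]})

before = pa.total_allocated_bytes()  # bytes currently held by pyarrow's default allocator
tbl.append(df)                       # write path that goes through _dataframe_to_data_files
after = pa.total_allocated_bytes()
print(f"before={before / 2**20:.2f} MB, after={after / 2**20:.2f} MB")
```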

File tree

2 files changed: +21 -28 lines changed

pyiceberg/io/pyarrow.py

Lines changed: 17 additions & 24 deletions
@@ -2790,30 +2790,26 @@ def _dataframe_to_data_files(
        yield from write_file(
            io=io,
            table_metadata=table_metadata,
-            tasks=iter(
-                [
-                    WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
-                    for batches in bin_pack_arrow_table(df, target_file_size)
-                ]
+            tasks=(
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
+                for batches in bin_pack_arrow_table(df, target_file_size)
            ),
        )
    else:
        partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
        yield from write_file(
            io=io,
            table_metadata=table_metadata,
-            tasks=iter(
-                [
-                    WriteTask(
-                        write_uuid=write_uuid,
-                        task_id=next(counter),
-                        record_batches=batches,
-                        partition_key=partition.partition_key,
-                        schema=task_schema,
-                    )
-                    for partition in partitions
-                    for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
-                ]
+            tasks=(
+                WriteTask(
+                    write_uuid=write_uuid,
+                    task_id=next(counter),
+                    record_batches=batches,
+                    partition_key=partition.partition_key,
+                    schema=task_schema,
+                )
+                for partition in partitions
+                for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
            ),
        )

@@ -2824,7 +2820,7 @@ class _TablePartition:
    arrow_table_partition: pa.Table


-def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> Iterable[_TablePartition]:
    """Based on the iceberg table partition spec, filter the arrow table into partitions with their keys.

    Example:
@@ -2852,8 +2848,6 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T

    unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])

-    table_partitions = []
-    # TODO: As a next step, we could also play around with yielding instead of materializing the full list
    for unique_partition in unique_partition_fields.to_pylist():
        partition_key = PartitionKey(
            field_values=[
@@ -2880,12 +2874,11 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T

        # The combine_chunks seems to be counter-intuitive to do, but it actually returns
        # fresh buffers that don't interfere with each other when it is written out to file
-        table_partitions.append(
-            _TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
+        yield _TablePartition(
+            partition_key=partition_key,
+            arrow_table_partition=filtered_table.combine_chunks(),
        )

-    return table_partitions
-

def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) -> pa.Array:
    """Get a field from an Arrow table, supporting both literal field names and nested field paths.

tests/io/test_pyarrow.py

Lines changed: 4 additions & 4 deletions
@@ -2479,7 +2479,7 @@ def test_partition_for_demo() -> None:
        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
    )
-    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    result = list(_determine_partitions(partition_spec, test_schema, arrow_table))
    assert {table_partition.partition_key.partition for table_partition in result} == {
        Record(2, 2020),
        Record(100, 2021),
@@ -2518,7 +2518,7 @@ def test_partition_for_nested_field() -> None:
    ]

    arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
-    partitions = _determine_partitions(spec, schema, arrow_table)
+    partitions = list(_determine_partitions(spec, schema, arrow_table))
    partition_values = {p.partition_key.partition[0] for p in partitions}

    assert partition_values == {486729, 486730}
@@ -2550,7 +2550,7 @@ def test_partition_for_deep_nested_field() -> None:
    ]

    arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
-    partitions = _determine_partitions(spec, schema, arrow_table)
+    partitions = list(_determine_partitions(spec, schema, arrow_table))

    assert len(partitions) == 2  # 2 unique partitions
    partition_values = {p.partition_key.partition[0] for p in partitions}
@@ -2621,7 +2621,7 @@ def test_identity_partition_on_multi_columns() -> None:
    }
    arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)

-    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    result = list(_determine_partitions(partition_spec, test_schema, arrow_table))

    assert {table_partition.partition_key.partition for table_partition in result} == expected
    concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])

0 commit comments
