
Commit b95e792

Use a join for upsert deduplication (#1685)
This changes the deduplication logic to use a join to single out the rows that actually changed. While the original design wasn't wrong, it is more efficient to push the work down into PyArrow, which gives better multi-threading and avoids the GIL. I did a small benchmark:

```python
import time

import pyarrow as pa

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.schema import Schema
from pyiceberg.types import NestedField, StringType, IntegerType


def _drop_table(catalog: Catalog, identifier: str) -> None:
    try:
        catalog.drop_table(identifier)
    except NoSuchTableError:
        pass


def test_vo(session_catalog: Catalog):
    catalog = session_catalog
    identifier = "default.test_upsert_benchmark"
    _drop_table(catalog, identifier)

    schema = Schema(
        NestedField(1, "idx", IntegerType(), required=True),
        NestedField(2, "number", IntegerType(), required=True),
        # Mark idx as the identifier field, also known as the primary-key
        identifier_field_ids=[1],
    )

    tbl = catalog.create_table(identifier, schema=schema)

    arrow_schema = pa.schema(
        [
            pa.field("idx", pa.int32(), nullable=False),
            pa.field("number", pa.int32(), nullable=False),
        ]
    )

    # Write some data
    df = pa.Table.from_pylist(
        [{"idx": idx, "number": idx} for idx in range(1, 100000)],
        schema=arrow_schema,
    )
    tbl.append(df)

    df_upsert = pa.Table.from_pylist(
        # Overlap
        [{"idx": idx, "number": idx} for idx in range(80000, 90000)]
        # Update
        + [{"idx": idx, "number": idx + 1} for idx in range(90000, 100000)]
        # Insert
        + [{"idx": idx, "number": idx} for idx in range(100000, 110000)],
        schema=arrow_schema,
    )

    start = time.time()
    tbl.upsert(df_upsert)
    stop = time.time()
    print(f"Took {stop-start} seconds")
```

And the result was:

```
Took 2.0412521362304688 seconds on the fd-join branch
Took 3.5236432552337646 seconds on latest main
```
1 parent 68a08b1 commit b95e792
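
To make the join-based approach concrete, here is a minimal, self-contained sketch of the technique on two toy PyArrow tables. The table contents and column names below are made up for illustration; the actual implementation lives in `pyiceberg/table/upsert_util.py` and is shown in the diff further down.

```python
# Toy sketch of join-based change detection, not the pyiceberg implementation itself.
import pyarrow as pa
import pyarrow.compute as pc

target = pa.table({"idx": [1, 2, 3], "number": [10, 20, 30]})
source = pa.table({"idx": [2, 3, 4], "number": [20, 31, 40]})  # idx=3 changed, idx=4 is new

# Inner join on the key; suffixes keep both sides of every colliding non-key column.
joined = source.join(
    target, keys=["idx"], join_type="inner", left_suffix="-lhs", right_suffix="-rhs"
)

# Keep only rows where a non-key column actually differs between source and target.
changed = joined.filter(pc.field("number-lhs") != pc.field("number-rhs"))
print(changed.to_pylist())  # one row: idx=3, number-lhs=31, number-rhs=30
```

Rows that match on the key but carry identical values (idx=2 above) are filtered out, and rows that only exist in the source (idx=4) drop out of the inner join entirely, which is exactly the "deduplication" the commit title refers to.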

File tree: 3 files changed (+67, −39 lines)


pyiceberg/table/__init__.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -1170,6 +1170,13 @@ def upsert(
         if upsert_util.has_duplicate_rows(df, join_cols):
             raise ValueError("Duplicate rows found in source dataset based on the key columns. No upsert executed")
 
+        from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible
+
+        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
+        _check_pyarrow_schema_compatible(
+            self.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+        )
+
         # get list of rows that exist so we don't have to load the entire target table
         matched_predicate = upsert_util.create_match_filter(df, join_cols)
         matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow()
```
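
The added check validates the incoming Arrow schema against the table schema before any data is read. `_check_pyarrow_schema_compatible` is a pyiceberg-internal helper, so the sketch below only illustrates, with plain PyArrow, the kind of mismatch the `downcast_ns_timestamp_to_us` flag is concerned with (Iceberg stores timestamps at microsecond precision, so nanosecond input has to be downcast); it is not the helper itself.

```python
# Hedged illustration only: plain PyArrow casting, not the
# _check_pyarrow_schema_compatible helper from pyiceberg.io.pyarrow.
import pyarrow as pa

# Incoming data with nanosecond timestamps (value chosen to be a whole microsecond,
# so the safe cast below is lossless).
arrow_ns = pa.table({"ts": pa.array([1_700_000_000_000_000_000], type=pa.timestamp("ns"))})

# Iceberg timestamps are microsecond precision, so ns columns are downcast to us,
# which is what the DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE config opts into.
us_schema = pa.schema([pa.field("ts", pa.timestamp("us"))])
arrow_us = arrow_ns.cast(us_schema)
print(arrow_us.schema)  # ts: timestamp[us]
```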

pyiceberg/table/upsert_util.py

Lines changed: 18 additions & 39 deletions
```diff
@@ -59,51 +59,30 @@ def get_rows_to_update(source_table: pa.Table, target_table: pa.Table, join_cols
     """
     Return a table with rows that need to be updated in the target table based on the join columns.
 
-    When a row is matched, an additional scan is done to evaluate the non-key columns to detect if an actual change has occurred.
-    Only matched rows that have an actual change to a non-key column value will be returned in the final output.
+    The table is joined on the identifier columns, and then checked if there are any updated rows.
+    Those are selected and everything is renamed correctly.
     """
     all_columns = set(source_table.column_names)
     join_cols_set = set(join_cols)
+    non_key_cols = all_columns - join_cols_set
 
-    non_key_cols = list(all_columns - join_cols_set)
+    if has_duplicate_rows(target_table, join_cols):
+        raise ValueError("Target table has duplicate rows, aborting upsert")
 
     if len(target_table) == 0:
         # When the target table is empty, there is nothing to update :)
         return source_table.schema.empty_table()
 
-    match_expr = functools.reduce(operator.and_, [pc.field(col).isin(target_table.column(col).to_pylist()) for col in join_cols])
-
-    matching_source_rows = source_table.filter(match_expr)
-
-    rows_to_update = []
-
-    for index in range(matching_source_rows.num_rows):
-        source_row = matching_source_rows.slice(index, 1)
-
-        target_filter = functools.reduce(operator.and_, [pc.field(col) == source_row.column(col)[0].as_py() for col in join_cols])
-
-        matching_target_row = target_table.filter(target_filter)
-
-        if matching_target_row.num_rows > 0:
-            needs_update = False
-
-            for non_key_col in non_key_cols:
-                source_value = source_row.column(non_key_col)[0].as_py()
-                target_value = matching_target_row.column(non_key_col)[0].as_py()
-
-                if source_value != target_value:
-                    needs_update = True
-                    break
-
-            if needs_update:
-                rows_to_update.append(source_row)
-
-    if rows_to_update:
-        rows_to_update_table = pa.concat_tables(rows_to_update)
-    else:
-        rows_to_update_table = source_table.schema.empty_table()
-
-    common_columns = set(source_table.column_names).intersection(set(target_table.column_names))
-    rows_to_update_table = rows_to_update_table.select(list(common_columns))
-
-    return rows_to_update_table
+    diff_expr = functools.reduce(operator.or_, [pc.field(f"{col}-lhs") != pc.field(f"{col}-rhs") for col in non_key_cols])
+
+    return (
+        source_table
+        # We already know that the schema is compatible, this is to fix large_ types
+        .cast(target_table.schema)
+        .join(target_table, keys=list(join_cols_set), join_type="inner", left_suffix="-lhs", right_suffix="-rhs")
+        .filter(diff_expr)
+        .drop_columns([f"{col}-rhs" for col in non_key_cols])
+        .rename_columns({f"{col}-lhs" if col not in join_cols else col: col for col in source_table.column_names})
+        # Finally cast to the original schema since it doesn't carry nullability:
+        # https://github.com/apache/arrow/issues/45557
+    ).cast(target_table.schema)
```
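
The rewrite also guards against duplicate keys on the target side by reusing the existing `has_duplicate_rows` helper, since joining against a target with duplicated keys would multiply the matched rows. As a hedged sketch of what such a check can look like in PyArrow (the actual helper in `pyiceberg.table.upsert_util` may be implemented differently):

```python
# Hedged sketch of a duplicate-key check over a PyArrow table; the real
# has_duplicate_rows helper in pyiceberg.table.upsert_util may differ.
import pyarrow as pa
import pyarrow.compute as pc


def has_duplicate_keys(table: pa.Table, join_cols: list[str]) -> bool:
    # Count rows per unique key combination; any count above 1 means a duplicate key.
    counts = table.select(join_cols).group_by(join_cols).aggregate([([], "count_all")])
    return counts.filter(pc.field("count_all") > 1).num_rows > 0


tbl = pa.table({"city": ["Drachten", "Drachten"], "inhabitants": [45019, 45019]})
print(has_duplicate_keys(tbl, ["city"]))  # True
```

This is the condition exercised by the new test below, which appends two rows with the same identifier value and expects the upsert to abort.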

tests/table/test_upsert.py

Lines changed: 42 additions & 0 deletions
```diff
@@ -427,6 +427,48 @@ def test_create_match_filter_single_condition() -> None:
     )
 
 
+def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
+    identifier = "default.test_upsert_with_duplicate_rows_in_table"
+
+    _drop_table(catalog, identifier)
+    schema = Schema(
+        NestedField(1, "city", StringType(), required=True),
+        NestedField(2, "inhabitants", IntegerType(), required=True),
+        # Mark City as the identifier field, also known as the primary-key
+        identifier_field_ids=[1],
+    )
+
+    tbl = catalog.create_table(identifier, schema=schema)
+
+    arrow_schema = pa.schema(
+        [
+            pa.field("city", pa.string(), nullable=False),
+            pa.field("inhabitants", pa.int32(), nullable=False),
+        ]
+    )
+
+    # Write some data
+    df = pa.Table.from_pylist(
+        [
+            {"city": "Drachten", "inhabitants": 45019},
+            {"city": "Drachten", "inhabitants": 45019},
+        ],
+        schema=arrow_schema,
+    )
+    tbl.append(df)
+
+    df = pa.Table.from_pylist(
+        [
+            # Would be an update, since the inhabitants value has changed
+            {"city": "Drachten", "inhabitants": 45505},
+        ],
+        schema=arrow_schema,
+    )
+
+    with pytest.raises(ValueError, match="Target table has duplicate rows, aborting upsert"):
+        _ = tbl.upsert(df)
+
+
 def test_upsert_without_identifier_fields(catalog: Catalog) -> None:
     identifier = "default.test_upsert_without_identifier_fields"
     _drop_table(catalog, identifier)
```
