Skip to content

Commit 5fb6fcd

Browse files
committed
Apply review suggestions
1 parent cdc3723 commit 5fb6fcd

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

src/zenml/zen_stores/migrations/versions/6e4eb89f632d_unique_run_index.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def upgrade() -> None:
4747
sa.select(
4848
run_table.c.id,
4949
run_table.c.pipeline_id,
50-
run_table.c.created,
5150
)
5251
.where(run_table.c.pipeline_id.is_not(None))
5352
.order_by(run_table.c.pipeline_id, run_table.c.created, run_table.c.id)
@@ -67,27 +66,34 @@ def upgrade() -> None:
6766
run_updates.append({"id_": row.id, "index": index_within_pipeline})
6867
run_counts[pipeline_id] = index_within_pipeline
6968

69+
update_batch_size = 10000
7070
if run_updates:
71-
connection.execute(
71+
update_statement = (
7272
sa.update(run_table)
7373
.where(run_table.c.id == sa.bindparam("id_"))
74-
.values(index=sa.bindparam("index")),
75-
run_updates,
74+
.values(index=sa.bindparam("index"))
7675
)
7776

77+
for start in range(0, len(run_updates), update_batch_size):
78+
batch = run_updates[start : start + update_batch_size]
79+
if batch:
80+
connection.execute(update_statement, batch)
81+
7882
if run_counts:
7983
pipeline_updates = [
8084
{"id_": pipeline_id, "run_count": run_count}
8185
for pipeline_id, run_count in run_counts.items()
8286
]
83-
connection.execute(
87+
update_statement = (
8488
sa.update(pipeline_table)
8589
.where(pipeline_table.c.id == sa.bindparam("id_"))
86-
.values(run_count=sa.bindparam("run_count")),
87-
pipeline_updates,
90+
.values(run_count=sa.bindparam("run_count"))
8891
)
92+
for start in range(0, len(pipeline_updates), update_batch_size):
93+
batch = pipeline_updates[start : start + update_batch_size]
94+
if batch:
95+
connection.execute(update_statement, batch)
8996

90-
# Step 3: Make columns non-nullable
9197
with op.batch_alter_table("pipeline_run", schema=None) as batch_op:
9298
batch_op.alter_column(
9399
"index", existing_type=sa.Integer(), nullable=False

0 commit comments

Comments
 (0)