Skip to content

Commit 22c272c

Browse files
committed
Add directly modified and restatement triggers
1 parent 9d3d0eb commit 22c272c

File tree

4 files changed

+95
-28
lines changed

4 files changed

+95
-28
lines changed

sqlmesh/core/console.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3830,12 +3830,10 @@ def update_snapshot_evaluation_progress(
38303830
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
38313831
if snapshot_evaluation_triggers.select_snapshot_triggers:
38323832
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
3833-
3834-
if snapshot_evaluation_triggers:
3835-
if snapshot_evaluation_triggers.auto_restatement_triggers:
3836-
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
3837-
if snapshot_evaluation_triggers.select_snapshot_triggers:
3838-
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
3833+
if snapshot_evaluation_triggers.directly_modified_triggers:
3834+
message += f" | directly_modified_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.directly_modified_triggers)}"
3835+
if snapshot_evaluation_triggers.restatement_triggers:
3836+
message += f" | restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.restatement_triggers)}"
38393837

38403838
if audit_only:
38413839
message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"

sqlmesh/core/plan/evaluator.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
SnapshotCreationFailedError,
3838
SnapshotNameVersion,
3939
)
40+
from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers
4041
from sqlmesh.utils import to_snake_case
4142
from sqlmesh.core.state_sync import StateSync
4243
from sqlmesh.utils import CorrelationId
@@ -83,6 +84,7 @@ def __init__(
8384
self.default_catalog = default_catalog
8485
self.console = console or get_console()
8586
self._circuit_breaker: t.Optional[t.Callable[[], bool]] = None
87+
self._restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
8688

8789
def evaluate(
8890
self,
@@ -234,6 +236,27 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
234236
self.console.log_success("SKIP: No model batches to execute")
235237
return
236238

239+
directly_modified_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
240+
for parent, children in plan.indirectly_modified_snapshots.items():
241+
parent_id = stage.all_snapshots[parent].snapshot_id
242+
directly_modified_triggers[parent_id] = directly_modified_triggers.get(
243+
parent_id, []
244+
) + [parent_id]
245+
for child in children:
246+
directly_modified_triggers[child] = directly_modified_triggers.get(child, []) + [
247+
parent_id
248+
]
249+
directly_modified_triggers = {
250+
k: list(dict.fromkeys(v)) for k, v in directly_modified_triggers.items()
251+
}
252+
snapshot_evaluation_triggers = {
253+
s_id: SnapshotEvaluationTriggers(
254+
directly_modified_triggers=directly_modified_triggers.get(s_id, []),
255+
restatement_triggers=self._restatement_triggers.get(s_id, []),
256+
)
257+
for s_id in [s.snapshot_id for s in stage.all_snapshots.values()]
258+
}
259+
237260
scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator)
238261
# Convert model name restatements to snapshot ID restatements
239262
restatements_by_snapshot_id = {
@@ -249,6 +272,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
249272
start=plan.start,
250273
end=plan.end,
251274
restatements=restatements_by_snapshot_id,
275+
snapshot_evaluation_triggers=snapshot_evaluation_triggers,
252276
)
253277
if errors:
254278
raise PlanError("Plan application failed.")
@@ -286,13 +310,14 @@ def visit_restatement_stage(
286310
# by forcing dev environments to re-run intervals that changed in prod
287311
#
288312
# Without this rule, it's possible that promoting a dev table to prod will introduce old data to prod
289-
snapshot_intervals_to_restate.update(
313+
restatement_intervals_all_environments, self._restatement_triggers = (
290314
self._restatement_intervals_across_all_environments(
291315
prod_restatements=plan.restatements,
292316
disable_restatement_models=plan.disabled_restatement_models,
293317
loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()},
294318
)
295319
)
320+
snapshot_intervals_to_restate.update(restatement_intervals_all_environments)
296321

297322
self.state_sync.remove_intervals(
298323
snapshot_intervals=list(snapshot_intervals_to_restate),
@@ -415,7 +440,9 @@ def _restatement_intervals_across_all_environments(
415440
prod_restatements: t.Dict[str, Interval],
416441
disable_restatement_models: t.Set[str],
417442
loaded_snapshots: t.Dict[SnapshotId, Snapshot],
418-
) -> t.Set[t.Tuple[SnapshotTableInfo, Interval]]:
443+
) -> t.Tuple[
444+
t.Set[t.Tuple[SnapshotTableInfo, Interval]], t.Dict[SnapshotId, t.List[SnapshotId]]
445+
]:
419446
"""
420447
Given a map of snapshot names + intervals to restate in prod:
421448
- Look up matching snapshots across all environments (match based on name - regardless of version)
@@ -426,14 +453,14 @@ def _restatement_intervals_across_all_environments(
426453
run in those environments causes the intervals to be repopulated
427454
"""
428455
if not prod_restatements:
429-
return set()
456+
return set(), {}
430457

431458
prod_name_versions: t.Set[SnapshotNameVersion] = {
432459
s.name_version for s in loaded_snapshots.values()
433460
}
434461

435462
snapshots_to_restate: t.Dict[SnapshotId, t.Tuple[SnapshotTableInfo, Interval]] = {}
436-
463+
restatement_downstream_ids: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
437464
for env_summary in self.state_sync.get_environments_summary():
438465
# Fetch the full environment object one at a time to avoid loading all environments into memory at once
439466
env = self.state_sync.get_environment(env_summary.name)
@@ -450,10 +477,17 @@ def _restatement_intervals_across_all_environments(
450477
for restatement, intervals in prod_restatements.items():
451478
if restatement not in keyed_snapshots:
452479
continue
480+
481+
downstream = env_dag.downstream(restatement)
482+
if not env.is_dev and restatement not in disable_restatement_models:
483+
restatement_downstream_ids[keyed_snapshots[restatement].snapshot_id] = [
484+
keyed_snapshots[name].snapshot_id
485+
for name in downstream
486+
if name not in disable_restatement_models
487+
]
488+
453489
affected_snapshot_names = [
454-
x
455-
for x in ([restatement] + env_dag.downstream(restatement))
456-
if x not in disable_restatement_models
490+
x for x in ([restatement] + downstream) if x not in disable_restatement_models
457491
]
458492
snapshots_to_restate.update(
459493
{
@@ -464,6 +498,14 @@ def _restatement_intervals_across_all_environments(
464498
}
465499
)
466500

501+
restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {
502+
id: [id] for id in restatement_downstream_ids
503+
}
504+
for parent, children in restatement_downstream_ids.items():
505+
for child in children:
506+
restatement_triggers[child] = restatement_triggers.get(child, []) + [parent]
507+
restatement_triggers = {k: list(dict.fromkeys(v)) for k, v in restatement_triggers.items()}
508+
467509
# for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
468510
# include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
469511
# so we only do it if necessary
@@ -499,7 +541,7 @@ def _restatement_intervals_across_all_environments(
499541
)
500542
snapshots_to_restate[full_snapshot_id] = (full_snapshot.table_info, new_intervals)
501543

502-
return set(snapshots_to_restate.values())
544+
return set(snapshots_to_restate.values()), restatement_triggers
503545

504546
def _update_intervals_for_new_snapshots(self, snapshots: t.Collection[Snapshot]) -> None:
505547
snapshots_intervals: t.List[SnapshotIntervals] = []

sqlmesh/core/snapshot/definition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,8 @@ class SnapshotEvaluationTriggers(PydanticModel):
330330
cron_ready: t.Optional[bool] = None
331331
auto_restatement_triggers: t.List[SnapshotId] = []
332332
select_snapshot_triggers: t.List[SnapshotId] = []
333+
directly_modified_triggers: t.List[SnapshotId] = []
334+
restatement_triggers: t.List[SnapshotId] = []
333335

334336

335337
class SnapshotInfoMixin(ModelKindMixin):

tests/core/test_integration.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,25 +1784,44 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt
17841784
context, plan = init_and_plan_context("examples/sushi")
17851785
context.apply(plan)
17861786

1787+
# modify 3 models
1788+
# - 2 breaking changes, to test the plan's directly modified triggers
1789+
# - 1 change adding an auto-restatement cron, for the subsequent `run` test
1790+
marketing = context.get_model("sushi.marketing")
1791+
marketing_kwargs = {
1792+
**marketing.dict(),
1793+
"query": d.parse_one(
1794+
f"{marketing.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb"
1795+
),
1796+
}
1797+
context.upsert_model(SqlModel.parse_obj(marketing_kwargs))
1798+
1799+
customers = context.get_model("sushi.customers")
1800+
customers_kwargs = {
1801+
**customers.dict(),
1802+
"query": d.parse_one(
1803+
f"{customers.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb"
1804+
),
1805+
}
1806+
context.upsert_model(SqlModel.parse_obj(customers_kwargs))
1807+
17871808
# add auto restatement to orders
1788-
model = context.get_model("sushi.orders")
1789-
kind = {
1790-
**model.kind.dict(),
1809+
orders = context.get_model("sushi.orders")
1810+
orders_kind = {
1811+
**orders.kind.dict(),
17911812
"auto_restatement_cron": "@hourly",
17921813
}
1793-
kwargs = {
1794-
**model.dict(),
1795-
"kind": kind,
1814+
orders_kwargs = {
1815+
**orders.dict(),
1816+
"kind": orders_kind,
17961817
}
1797-
context.upsert_model(PythonModel.parse_obj(kwargs))
1798-
plan = context.plan_builder(skip_tests=True).build()
1799-
context.apply(plan)
1818+
context.upsert_model(PythonModel.parse_obj(orders_kwargs))
18001819

1801-
# Mock run_merged_intervals to capture triggers arg
1802-
scheduler = context.scheduler()
1803-
run_merged_intervals_mock = mocker.patch.object(
1804-
scheduler, "run_merged_intervals", return_value=([], [])
1805-
)
1820+
# spy = mocker.spy(sqlmesh.core.scheduler, "run_merged_intervals")
1821+
1822+
context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full())
1823+
1824+
# assert spy.call_args.kwargs["snapshot_evaluation_triggers"]
18061825

18071826
# User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream
18081827
selected_models = {"top_waiters", "waiter_revenue_by_day"}
@@ -1814,6 +1833,12 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt
18141833
f'"memory"."sushi"."{model}"' for model in selected_models
18151834
}
18161835

1836+
# Mock run_merged_intervals to capture triggers arg
1837+
scheduler = context.scheduler()
1838+
run_merged_intervals_mock = mocker.patch.object(
1839+
scheduler, "run_merged_intervals", return_value=([], [])
1840+
)
1841+
18171842
with time_machine.travel("2023-01-09 00:00:01 UTC"):
18181843
scheduler.run(
18191844
environment=c.PROD,

0 commit comments

Comments
 (0)