
Commit a01cb9f

Collect selected snapshot triggers
1 parent 1b0ad0e commit a01cb9f

File tree: 8 files changed, +129 -40 lines


sqlmesh/core/console.py

Lines changed: 12 additions & 2 deletions

@@ -3833,7 +3833,17 @@ def update_snapshot_evaluation_progress(
         audit_only: bool = False,
         snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None,
     ) -> None:
-        message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
+        message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
+
+        if snapshot_evaluation_triggers:
+            if snapshot_evaluation_triggers.ignore_cron_flag is not None:
+                message += f" | ignore_cron_flag={snapshot_evaluation_triggers.ignore_cron_flag}"
+            if snapshot_evaluation_triggers.cron_ready is not None:
+                message += f" | cron_ready={snapshot_evaluation_triggers.cron_ready}"
+            if snapshot_evaluation_triggers.auto_restatement_triggers:
+                message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
+            if snapshot_evaluation_triggers.select_snapshot_triggers:
+                message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
 
         if snapshot_evaluation_triggers:
             if snapshot_evaluation_triggers.auto_restatement_triggers:
@@ -3842,7 +3852,7 @@ def update_snapshot_evaluation_progress(
                 message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
 
         if audit_only:
-            message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
+            message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
 
         self._write(message)

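To make the new output concrete, here is a small self-contained sketch of how update_snapshot_evaluation_progress now accumulates the trigger fields onto the evaluation log line. The dataclasses and values below are illustrative stand-ins for sqlmesh's SnapshotId and SnapshotEvaluationTriggers, and the audit counters are omitted for brevity; this is not the console class itself.

import typing as t
from dataclasses import dataclass, field

# Illustrative stand-ins only; the real types live in sqlmesh/core/snapshot/definition.py.
@dataclass
class FakeId:
    name: str

@dataclass
class FakeTriggers:
    ignore_cron_flag: t.Optional[bool] = None
    cron_ready: t.Optional[bool] = None
    auto_restatement_triggers: t.List[FakeId] = field(default_factory=list)
    select_snapshot_triggers: t.List[FakeId] = field(default_factory=list)

def format_evaluation_line(name: str, batch_idx: int, duration_ms: int, triggers: t.Optional[FakeTriggers]) -> str:
    # Mirrors the message-building logic added in this hunk (audit counters dropped).
    message = f"Evaluated {name} | batch={batch_idx} | duration={duration_ms}ms"
    if triggers:
        if triggers.ignore_cron_flag is not None:
            message += f" | ignore_cron_flag={triggers.ignore_cron_flag}"
        if triggers.cron_ready is not None:
            message += f" | cron_ready={triggers.cron_ready}"
        if triggers.auto_restatement_triggers:
            message += f" | auto_restatement_triggers={','.join(trig.name for trig in triggers.auto_restatement_triggers)}"
        if triggers.select_snapshot_triggers:
            message += f" | select_snapshot_triggers={','.join(trig.name for trig in triggers.select_snapshot_triggers)}"
    return message

print(
    format_evaluation_line(
        '"memory"."sushi"."waiter_revenue_by_day"',
        batch_idx=0,
        duration_ms=42,
        triggers=FakeTriggers(cron_ready=True, select_snapshot_triggers=[FakeId('"memory"."sushi"."top_waiters"')]),
    )
)
# -> Evaluated "memory"."sushi"."waiter_revenue_by_day" | batch=0 | duration=42ms | cron_ready=True | select_snapshot_triggers="memory"."sushi"."top_waiters"
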
sqlmesh/core/context.py

Lines changed: 4 additions & 3 deletions

@@ -2887,9 +2887,10 @@ def _select_models_for_run(
             dag.add(fqn, model.depends_on)
         model_selector = self._new_selector(models=models, dag=dag)
         result = set(model_selector.expand_model_selections(select_models))
-        if not no_auto_upstream:
-            result_with_upstream = set(dag.subdag(*result))
-        return result, result_with_upstream - result
+        if no_auto_upstream:
+            return result, set()
+        result_with_upstream = set(dag.subdag(*result))
+        return result_with_upstream, result_with_upstream - result
 
     @cached_property
     def _project_type(self) -> str:

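For intuition, a minimal stand-alone sketch of the selection semantics after this change: when auto-upstream is enabled, the first element of the returned pair is now the full set to run (selection plus transitive upstream), and the second element is just the auto-added upstream part. The toy dependency map and upstream_closure helper below are illustrative stand-ins, not sqlmesh's DAG class.

import typing as t

# Toy dependency map: model -> its upstream dependencies (illustrative names only).
DEPS: t.Dict[str, t.Set[str]] = {
    "top_waiters": {"waiter_revenue_by_day"},
    "waiter_revenue_by_day": {"orders", "items"},
    "orders": set(),
    "items": set(),
}

def upstream_closure(selected: t.Set[str]) -> t.Set[str]:
    """Stand-in for dag.subdag(*selected): the selection plus all transitive upstream deps."""
    seen: t.Set[str] = set()
    stack = list(selected)
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(DEPS.get(node, set()))
    return seen

def select_models_for_run(selected: t.Set[str], no_auto_upstream: bool) -> t.Tuple[t.Set[str], t.Set[str]]:
    result = set(selected)
    if no_auto_upstream:
        return result, set()
    result_with_upstream = upstream_closure(result)
    return result_with_upstream, result_with_upstream - result

to_run, auto_added = select_models_for_run({"top_waiters"}, no_auto_upstream=False)
assert to_run == {"top_waiters", "waiter_revenue_by_day", "orders", "items"}
assert auto_added == {"waiter_revenue_by_day", "orders", "items"}
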
sqlmesh/core/plan/stages.py

Lines changed: 2 additions & 1 deletion

@@ -516,7 +516,7 @@ def _missing_intervals(
         snapshots_by_name: t.Dict[str, Snapshot],
         deployability_index: DeployabilityIndex,
     ) -> SnapshotToIntervals:
-        return merged_missing_intervals(
+        missing_intervals, _ = merged_missing_intervals(
             snapshots=snapshots_by_name.values(),
             start=plan.start,
             end=plan.end,
@@ -530,6 +530,7 @@ def _missing_intervals(
             start_override_per_model=plan.start_override_per_model,
             end_override_per_model=plan.end_override_per_model,
         )
+        return missing_intervals
 
     def _get_audit_only_snapshots(
         self, new_snapshots: t.Dict[SnapshotId, Snapshot]

sqlmesh/core/scheduler.py

Lines changed: 26 additions & 20 deletions

@@ -112,7 +112,7 @@ def merged_missing_intervals(
         ignore_cron: bool = False,
         end_bounded: bool = False,
         selected_snapshots: t.Optional[t.Set[str]] = None,
-    ) -> SnapshotToIntervals:
+    ) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]:
         """Find the largest contiguous date interval parameters based only on what is missing.
 
         For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found,
@@ -132,8 +132,11 @@ def merged_missing_intervals(
             end_bounded: If set to true, the returned intervals will be bounded by the target end date, disregarding lookback,
                 allow_partials, and other attributes that could cause the intervals to exceed the target end date.
             selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run.
+
+        Returns:
+            A tuple containing a dict containing all snapshots needing to be run with their associated interval params and a list of snapshots that are ready to run based on their naive cron schedule (ignoring plan/run context and other attributes).
         """
-        snapshots_to_intervals = merged_missing_intervals(
+        snapshots_to_intervals, snapshots_naive_cron_ready = merged_missing_intervals(
             snapshots=self.snapshot_per_version.values(),
             start=start,
             end=end,
@@ -151,7 +154,7 @@ def merged_missing_intervals(
             snapshots_to_intervals = {
                 s: i for s, i in snapshots_to_intervals.items() if s.name in selected_snapshots
             }
-        return snapshots_to_intervals
+        return snapshots_to_intervals, snapshots_naive_cron_ready
 
     def evaluate(
         self,
@@ -658,7 +661,7 @@ def _run_or_audit(
             {s.name_version: s.next_auto_restatement_ts for s in self.snapshots.values()}
         )
 
-        merged_intervals = self.merged_missing_intervals(
+        merged_intervals, snapshots_naive_cron_ready = self.merged_missing_intervals(
             start,
             end,
             execution_time,
@@ -673,9 +676,7 @@ def _run_or_audit(
         if not merged_intervals:
             return CompletionStatus.NOTHING_TO_DO
 
-        merged_intervals_snapshots = {
-            snapshot.snapshot_id: snapshot for snapshot in merged_intervals.keys()
-        }
+        merged_intervals_snapshots = {snapshot.snapshot_id for snapshot in merged_intervals}
         select_snapshot_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
         if selected_snapshots and selected_snapshots_auto_upstream:
             # actually selected snapshots are their own triggers
@@ -691,24 +692,25 @@ def _run_or_audit(
                 ]
             }
 
-            # trace upstream by reversing dag of all snapshots to evaluate
-            reversed_intervals_dag = snapshots_to_dag(merged_intervals_snapshots.values()).reversed
-            for s_id in reversed_intervals_dag:
-                if s_id not in select_snapshot_triggers:
-                    triggers = []
-                    for parent_s_id in merged_intervals_snapshots[s_id].parents:
-                        triggers.extend(select_snapshot_triggers[parent_s_id])
+            # trace upstream by walking downstream on reversed dag
+            reversed_dag = snapshots_to_dag(self.snapshots.values()).reversed
+            for s_id in reversed_dag:
+                if s_id in merged_intervals_snapshots:
+                    triggers = select_snapshot_triggers.get(s_id, [])
+                    for parent_s_id in reversed_dag.graph.get(s_id, set()):
+                        triggers.extend(select_snapshot_triggers.get(parent_s_id, []))
                     select_snapshot_triggers[s_id] = list(dict.fromkeys(triggers))
 
         all_snapshot_triggers: t.Dict[SnapshotId, SnapshotEvaluationTriggers] = {
             s_id: SnapshotEvaluationTriggers(
-                ignore_cron=ignore_cron,
+                ignore_cron_flag=ignore_cron,
+                cron_ready=s_id in snapshots_naive_cron_ready,
                 auto_restatement_triggers=auto_restatement_triggers.get(s_id, []),
                 select_snapshot_triggers=select_snapshot_triggers.get(s_id, []),
            )
            for s_id in merged_intervals_snapshots
-            if ignore_cron or s_id in auto_restatement_triggers or s_id in select_snapshot_triggers
        }
+
        errors, _ = self.run_merged_intervals(
            merged_intervals=merged_intervals,
            deployability_index=deployability_index,
@@ -870,7 +872,7 @@ def merged_missing_intervals(
     end_override_per_model: t.Optional[t.Dict[str, datetime]] = None,
     ignore_cron: bool = False,
     end_bounded: bool = False,
-) -> SnapshotToIntervals:
+) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]:
     """Find the largest contiguous date interval parameters based only on what is missing.
 
     For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found,
@@ -920,7 +922,7 @@ def compute_interval_params(
     end_override_per_model: t.Optional[t.Dict[str, datetime]] = None,
     ignore_cron: bool = False,
     end_bounded: bool = False,
-) -> SnapshotToIntervals:
+) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]:
     """Find the largest contiguous date interval parameters based only on what is missing.
 
     For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found,
@@ -942,7 +944,7 @@ def compute_interval_params(
            allow_partials, and other attributes that could cause the intervals to exceed the target end date.
 
     Returns:
-        A dict containing all snapshots needing to be run with their associated interval params.
+        A tuple containing a dict containing all snapshots needing to be run with their associated interval params and a list of snapshots that are ready to run based on their naive cron schedule (ignoring plan/run context and other attributes).
     """
     snapshot_merged_intervals = {}
 
@@ -970,7 +972,11 @@ def compute_interval_params(
         contiguous_batch.append((next_batch[0][0], next_batch[-1][-1]))
         snapshot_merged_intervals[snapshot] = contiguous_batch
 
-    return snapshot_merged_intervals
+    snapshots_naive_cron_ready = [
+        snap.snapshot_id for snap in missing_intervals(snapshots, execution_time=execution_time)
+    ]
+
+    return snapshot_merged_intervals, snapshots_naive_cron_ready
 
 
 def interval_diff(

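For intuition, a minimal stand-alone sketch of the new select-trigger tracing in _run_or_audit: explicitly selected snapshots are their own triggers, and walking the reversed DAG (downstream nodes first) propagates those triggers to every upstream snapshot that runs because of them. The toy dependency map mirrors the sushi example in the new test, and graphlib stands in for snapshots_to_dag(...).reversed; neither is sqlmesh's own API.

import typing as t
from graphlib import TopologicalSorter

# Toy dependency map: node -> its upstream dependencies (illustrative names only).
DEPS: t.Dict[str, t.Set[str]] = {
    "top_waiters": {"waiter_revenue_by_day"},
    "waiter_revenue_by_day": {"orders", "items", "order_items"},
    "order_items": {"orders", "items"},
    "orders": set(),
    "items": set(),
}

def trace_select_triggers(selected: t.Set[str], to_run: t.Set[str]) -> t.Dict[str, t.List[str]]:
    # Explicitly selected snapshots are their own triggers.
    triggers: t.Dict[str, t.List[str]] = {s: [s] for s in selected}

    # Reversed graph: node -> its downstream dependents.
    downstream: t.Dict[str, t.Set[str]] = {node: set() for node in DEPS}
    for node, parents in DEPS.items():
        for parent in parents:
            downstream[parent].add(node)

    # Visit downstream nodes before upstream ones (reverse topological order),
    # so each node inherits the triggers of everything that depends on it.
    for node in TopologicalSorter(downstream).static_order():
        if node in to_run:
            collected = triggers.get(node, [])
            for dependent in downstream[node]:
                collected.extend(triggers.get(dependent, []))
            triggers[node] = list(dict.fromkeys(collected))  # dedupe, keep order
    return triggers

triggers = trace_select_triggers(
    selected={"top_waiters", "waiter_revenue_by_day"},
    to_run=set(DEPS),
)
assert triggers["top_waiters"] == ["top_waiters"]
assert set(triggers["items"]) == {"waiter_revenue_by_day", "top_waiters"}
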
sqlmesh/core/snapshot/definition.py

Lines changed: 2 additions & 3 deletions

@@ -326,11 +326,10 @@ def table_name_for_environment(
 
 
 class SnapshotEvaluationTriggers(PydanticModel):
-    ignore_cron: bool
+    ignore_cron_flag: t.Optional[bool] = None
+    cron_ready: t.Optional[bool] = None
     auto_restatement_triggers: t.List[SnapshotId] = []
     select_snapshot_triggers: t.List[SnapshotId] = []
-    directly_modified_triggers: t.List[SnapshotId] = []
-    manual_restatement_triggers: t.List[SnapshotId] = []
 
 
 class SnapshotInfoMixin(ModelKindMixin):

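A minimal usage sketch of the updated model, assuming both classes are importable from sqlmesh.core.snapshot.definition (the module this commit edits) and that SnapshotId is constructed with name and identifier keyword arguments; the values are invented, not taken from the commit.

# Hypothetical values; the field names come from the class definition in this diff.
from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers, SnapshotId

triggers = SnapshotEvaluationTriggers(
    ignore_cron_flag=False,
    cron_ready=True,
    auto_restatement_triggers=[],  # no auto-restatement fired for this snapshot
    select_snapshot_triggers=[
        SnapshotId(name='"memory"."sushi"."top_waiters"', identifier="1a2b3c4d"),  # assumed constructor
    ],
)
assert triggers.cron_ready
assert [s.name for s in triggers.select_snapshot_triggers] == ['"memory"."sushi"."top_waiters"']
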
tests/core/test_integration.py

Lines changed: 78 additions & 7 deletions

@@ -1779,6 +1779,77 @@ def test_select_unchanged_model_for_backfill(init_and_plan_context: t.Callable):
     assert {o.name for o in schema_objects} == {"waiter_revenue_by_day", "top_waiters"}
 
 
+@time_machine.travel("2023-01-08 00:00:00 UTC")
+def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixture):
+    context, plan = init_and_plan_context("examples/sushi")
+    context.apply(plan)
+
+    # add auto restatement to orders
+    model = context.get_model("sushi.orders")
+    kind = {
+        **model.kind.dict(),
+        "auto_restatement_cron": "@hourly",
+    }
+    kwargs = {
+        **model.dict(),
+        "kind": kind,
+    }
+    context.upsert_model(PythonModel.parse_obj(kwargs))
+    plan = context.plan_builder(skip_tests=True).build()
+    context.apply(plan)
+
+    # Mock run_merged_intervals to capture triggers arg
+    scheduler = context.scheduler()
+    run_merged_intervals_mock = mocker.patch.object(
+        scheduler, "run_merged_intervals", return_value=([], [])
+    )
+
+    # User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream
+    selected_models = {"top_waiters", "waiter_revenue_by_day"}
+    selected_models_auto_upstream = {"order_items", "orders", "items"}
+    selected_snapshots = {
+        f'"memory"."sushi"."{model}"' for model in selected_models | selected_models_auto_upstream
+    }
+    selected_snapshots_auto_upstream = selected_snapshots - {
+        f'"memory"."sushi"."{model}"' for model in selected_models
+    }
+
+    with time_machine.travel("2023-01-09 00:00:01 UTC"):
+        scheduler.run(
+            environment=c.PROD,
+            selected_snapshots=selected_snapshots,
+            selected_snapshots_auto_upstream=selected_snapshots_auto_upstream,
+            start="2023-01-01",
+            auto_restatement_enabled=True,
+        )
+
+    assert run_merged_intervals_mock.called
+
+    actual_triggers = run_merged_intervals_mock.call_args.kwargs["snapshot_evaluation_triggers"]
+
+    # validate ignore_cron not passed and all model crons ready
+    assert all(
+        not trigger.ignore_cron_flag and trigger.cron_ready for trigger in actual_triggers.values()
+    )
+
+    for id, trigger in actual_triggers.items():
+        # top_waiters is its own trigger, waiter_revenue_by_day is upstream of it, everyone else is upstream of both
+        select_triggers = [t.name for t in trigger.select_snapshot_triggers]
+        assert (
+            select_triggers == ['"memory"."sushi"."top_waiters"']
+            if id.name == '"memory"."sushi"."top_waiters"'
+            else ['"memory"."sushi"."waiter_revenue_by_day"', '"memory"."sushi"."top_waiters"']
+        )
+
+        # everyone other than items is downstream of orders
+        auto_restatement_triggers = [t.name for t in trigger.auto_restatement_triggers]
+        assert (
+            auto_restatement_triggers == []
+            if id.name == '"memory"."sushi"."items"'
+            else ['"memory"."sushi"."orders"']
+        )
+
+
 @time_machine.travel("2023-01-08 15:00:00 UTC")
 def test_max_interval_end_per_model_not_applied_when_end_is_provided(
     init_and_plan_context: t.Callable,
@@ -6597,13 +6668,13 @@ def plan_with_output(ctx: Context, environment: str):
     assert "Differences from the `prod` environment" in output.stdout
 
     assert (
-        """MODEL (
-  name test.a,
-+ owner test,
-  kind FULL
-)
-SELECT
-- 5 AS col
+        """MODEL (
+  name test.a,
++ owner test,
+  kind FULL
+)
+SELECT
+- 5 AS col
 + 10 AS col"""
         in output.stdout
     )

tests/core/test_scheduler.py

Lines changed: 4 additions & 3 deletions

@@ -76,7 +76,7 @@ def _get_batched_missing_intervals(
         end: TimeLike,
         execution_time: t.Optional[TimeLike] = None,
     ) -> SnapshotToIntervals:
-        merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time)
+        merged_intervals, _ = scheduler.merged_missing_intervals(start, end, execution_time)
         return scheduler.batch_intervals(merged_intervals, mocker.Mock(), mocker.Mock())
 
     return _get_batched_missing_intervals
@@ -104,9 +104,10 @@ def test_interval_params_missing(scheduler: Scheduler, sushi_context_fixed_date:
 
     start_ds = "2022-01-01"
     end_ds = "2022-03-01"
-    assert compute_interval_params(
+    interval_params, _ = compute_interval_params(
         sushi_context_fixed_date.snapshots.values(), start=start_ds, end=end_ds
-    )[waiters] == [
+    )
+    assert interval_params[waiters] == [
         (to_timestamp(start_ds), to_timestamp("2022-03-02")),
     ]

web/server/api/endpoints/plan.py

Lines changed: 1 addition & 1 deletion

@@ -132,7 +132,7 @@ def _get_plan_changes(context: Context, plan: Plan) -> models.PlanChanges:
 
 def _get_plan_backfills(context: Context, plan: Plan) -> t.Dict[str, t.Any]:
     """Get plan backfills"""
-    merged_intervals = context.scheduler().merged_missing_intervals()
+    merged_intervals, _ = context.scheduler().merged_missing_intervals()
     batches = context.scheduler().batch_intervals(merged_intervals, None, EnvironmentNamingInfo())
     tasks = {snapshot.name: len(intervals) for snapshot, intervals in batches.items()}
     snapshots = plan.context_diff.snapshots
