Skip to content

Commit 1b0ad0e

Browse files
committed
Add triggers class and selected snapshot triggers
1 parent 9c594b6 commit 1b0ad0e

File tree

5 files changed

+81
-21
lines changed

5 files changed

+81
-21
lines changed

sqlmesh/core/console.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@
3737
SnapshotId,
3838
SnapshotInfoLike,
3939
)
40-
from sqlmesh.core.snapshot.definition import Interval, Intervals, SnapshotTableInfo
40+
from sqlmesh.core.snapshot.definition import (
41+
Interval,
42+
Intervals,
43+
SnapshotTableInfo,
44+
SnapshotEvaluationTriggers,
45+
)
4146
from sqlmesh.core.test import ModelTest
4247
from sqlmesh.utils import rich as srich
4348
from sqlmesh.utils import Verbosity
@@ -428,7 +433,7 @@ def update_snapshot_evaluation_progress(
428433
num_audits_passed: int,
429434
num_audits_failed: int,
430435
audit_only: bool = False,
431-
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
436+
snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None,
432437
) -> None:
433438
"""Updates the snapshot evaluation progress."""
434439

@@ -576,7 +581,7 @@ def update_snapshot_evaluation_progress(
576581
num_audits_passed: int,
577582
num_audits_failed: int,
578583
audit_only: bool = False,
579-
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
584+
snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None,
580585
) -> None:
581586
pass
582587

@@ -1058,7 +1063,7 @@ def update_snapshot_evaluation_progress(
10581063
num_audits_passed: int,
10591064
num_audits_failed: int,
10601065
audit_only: bool = False,
1061-
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
1066+
snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None,
10621067
) -> None:
10631068
"""Update the snapshot evaluation progress."""
10641069
if (
@@ -3656,7 +3661,7 @@ def update_snapshot_evaluation_progress(
36563661
num_audits_passed: int,
36573662
num_audits_failed: int,
36583663
audit_only: bool = False,
3659-
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
3664+
snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None,
36603665
) -> None:
36613666
view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id]
36623667

@@ -3826,12 +3831,15 @@ def update_snapshot_evaluation_progress(
38263831
num_audits_passed: int,
38273832
num_audits_failed: int,
38283833
audit_only: bool = False,
3829-
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
3834+
snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None,
38303835
) -> None:
38313836
message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"
38323837

3833-
if auto_restatement_triggers:
3834-
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in auto_restatement_triggers)}"
3838+
if snapshot_evaluation_triggers:
3839+
if snapshot_evaluation_triggers.auto_restatement_triggers:
3840+
message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}"
3841+
if snapshot_evaluation_triggers.select_snapshot_triggers:
3842+
message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}"
38353843

38363844
if audit_only:
38373845
message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}"

sqlmesh/core/context.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2296,11 +2296,9 @@ def check_intervals(
22962296
}
22972297

22982298
if select_models:
2299-
selected: t.Collection[str] = self._select_models_for_run(
2300-
select_models, True, snapshots.values()
2301-
)
2299+
selected, _ = self._select_models_for_run(select_models, True, snapshots.values())
23022300
else:
2303-
selected = snapshots.keys()
2301+
selected = t.cast(t.Set[str], snapshots.keys())
23042302

23052303
results = {}
23062304
execution_context = self.execution_context(snapshots=snapshots)
@@ -2450,8 +2448,9 @@ def _run(
24502448
scheduler = self.scheduler(environment=environment)
24512449
snapshots = scheduler.snapshots
24522450

2451+
select_models_auto_upstream = None
24532452
if select_models is not None:
2454-
select_models = self._select_models_for_run(
2453+
select_models, select_models_auto_upstream = self._select_models_for_run(
24552454
select_models, no_auto_upstream, snapshots.values()
24562455
)
24572456

@@ -2463,6 +2462,7 @@ def _run(
24632462
ignore_cron=ignore_cron,
24642463
circuit_breaker=circuit_breaker,
24652464
selected_snapshots=select_models,
2465+
selected_snapshots_auto_upstream=select_models_auto_upstream,
24662466
auto_restatement_enabled=environment.lower() == c.PROD,
24672467
run_environment_statements=True,
24682468
)
@@ -2878,7 +2878,7 @@ def _select_models_for_run(
28782878
select_models: t.Collection[str],
28792879
no_auto_upstream: bool,
28802880
snapshots: t.Collection[Snapshot],
2881-
) -> t.Set[str]:
2881+
) -> t.Tuple[t.Set[str], t.Set[str]]:
28822882
models: UniqueKeyDict[str, Model] = UniqueKeyDict(
28832883
"models", **{s.name: s.model for s in snapshots if s.is_model}
28842884
)
@@ -2888,8 +2888,8 @@ def _select_models_for_run(
28882888
model_selector = self._new_selector(models=models, dag=dag)
28892889
result = set(model_selector.expand_model_selections(select_models))
28902890
if not no_auto_upstream:
2891-
result = set(dag.subdag(*result))
2892-
return result
2891+
result_with_upstream = set(dag.subdag(*result))
2892+
return result, result_with_upstream - result
28932893

28942894
@cached_property
28952895
def _project_type(self) -> str:

sqlmesh/core/scheduler.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sqlmesh.core.snapshot.definition import check_ready_intervals
3030
from sqlmesh.core.snapshot.definition import (
3131
Interval,
32+
SnapshotEvaluationTriggers,
3233
expand_range,
3334
parent_snapshots_by_name,
3435
)
@@ -223,6 +224,7 @@ def run(
223224
ignore_cron: bool = False,
224225
end_bounded: bool = False,
225226
selected_snapshots: t.Optional[t.Set[str]] = None,
227+
selected_snapshots_auto_upstream: t.Optional[t.Set[str]] = None,
226228
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
227229
deployability_index: t.Optional[DeployabilityIndex] = None,
228230
auto_restatement_enabled: bool = False,
@@ -239,6 +241,7 @@ def run(
239241
ignore_cron=ignore_cron,
240242
end_bounded=end_bounded,
241243
selected_snapshots=selected_snapshots,
244+
selected_snapshots_auto_upstream=selected_snapshots_auto_upstream,
242245
circuit_breaker=circuit_breaker,
243246
deployability_index=deployability_index,
244247
auto_restatement_enabled=auto_restatement_enabled,
@@ -374,7 +377,7 @@ def run_merged_intervals(
374377
run_environment_statements: bool = False,
375378
audit_only: bool = False,
376379
restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None,
377-
auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {},
380+
snapshot_evaluation_triggers: t.Dict[SnapshotId, SnapshotEvaluationTriggers] = {},
378381
) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]:
379382
"""Runs precomputed batches of missing intervals.
380383
@@ -477,7 +480,9 @@ def evaluate_node(node: SchedulingUnit) -> None:
477480
evaluation_duration_ms,
478481
num_audits - num_audits_failed,
479482
num_audits_failed,
480-
auto_restatement_triggers=auto_restatement_triggers.get(snapshot.snapshot_id),
483+
snapshot_evaluation_triggers=snapshot_evaluation_triggers.get(
484+
snapshot.snapshot_id
485+
),
481486
)
482487

483488
try:
@@ -588,6 +593,7 @@ def _run_or_audit(
588593
ignore_cron: bool = False,
589594
end_bounded: bool = False,
590595
selected_snapshots: t.Optional[t.Set[str]] = None,
596+
selected_snapshots_auto_upstream: t.Optional[t.Set[str]] = None,
591597
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
592598
deployability_index: t.Optional[DeployabilityIndex] = None,
593599
auto_restatement_enabled: bool = False,
@@ -611,6 +617,7 @@ def _run_or_audit(
611617
end_bounded: If set to true, the evaluated intervals will be bounded by the target end date, disregarding lookback,
612618
allow_partials, and other attributes that could cause the intervals to exceed the target end date.
613619
selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run.
620+
selected_snapshots_auto_upstream: A set of snapshot names that were not selected explicitly but were added automatically because they are upstream of a selected snapshot.
614621
circuit_breaker: An optional handler which checks if the run should be aborted.
615622
deployability_index: Determines snapshots that are deployable in the context of this render.
616623
auto_restatement_enabled: Whether to enable auto restatements.
@@ -666,6 +673,42 @@ def _run_or_audit(
666673
if not merged_intervals:
667674
return CompletionStatus.NOTHING_TO_DO
668675

676+
merged_intervals_snapshots = {
677+
snapshot.snapshot_id: snapshot for snapshot in merged_intervals.keys()
678+
}
679+
select_snapshot_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}
680+
if selected_snapshots and selected_snapshots_auto_upstream:
681+
# actually selected snapshots are their own triggers
682+
selected_snapshots_no_auto_upstream = (
683+
selected_snapshots - selected_snapshots_auto_upstream
684+
)
685+
select_snapshot_triggers = {
686+
s_id: [s_id]
687+
for s_id in [
688+
snapshot_id
689+
for snapshot_id in merged_intervals_snapshots
690+
if snapshot_id.name in selected_snapshots_no_auto_upstream
691+
]
692+
}
693+
694+
# trace upstream by reversing dag of all snapshots to evaluate
695+
reversed_intervals_dag = snapshots_to_dag(merged_intervals_snapshots.values()).reversed
696+
for s_id in reversed_intervals_dag:
697+
if s_id not in select_snapshot_triggers:
698+
triggers = []
699+
for parent_s_id in merged_intervals_snapshots[s_id].parents:
700+
triggers.extend(select_snapshot_triggers[parent_s_id])
701+
select_snapshot_triggers[s_id] = list(dict.fromkeys(triggers))
702+
703+
all_snapshot_triggers: t.Dict[SnapshotId, SnapshotEvaluationTriggers] = {
704+
s_id: SnapshotEvaluationTriggers(
705+
ignore_cron=ignore_cron,
706+
auto_restatement_triggers=auto_restatement_triggers.get(s_id, []),
707+
select_snapshot_triggers=select_snapshot_triggers.get(s_id, []),
708+
)
709+
for s_id in merged_intervals_snapshots
710+
if ignore_cron or s_id in auto_restatement_triggers or s_id in select_snapshot_triggers
711+
}
669712
errors, _ = self.run_merged_intervals(
670713
merged_intervals=merged_intervals,
671714
deployability_index=deployability_index,
@@ -677,7 +720,7 @@ def _run_or_audit(
677720
run_environment_statements=run_environment_statements,
678721
audit_only=audit_only,
679722
restatements=remove_intervals,
680-
auto_restatement_triggers=auto_restatement_triggers,
723+
snapshot_evaluation_triggers=all_snapshot_triggers,
681724
)
682725

683726
return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS

sqlmesh/core/snapshot/definition.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,14 @@ def table_name_for_environment(
325325
return table
326326

327327

328+
class SnapshotEvaluationTriggers(PydanticModel):
329+
ignore_cron: bool
330+
auto_restatement_triggers: t.List[SnapshotId] = []
331+
select_snapshot_triggers: t.List[SnapshotId] = []
332+
directly_modified_triggers: t.List[SnapshotId] = []
333+
manual_restatement_triggers: t.List[SnapshotId] = []
334+
335+
328336
class SnapshotInfoMixin(ModelKindMixin):
329337
name: str
330338
dev_version_: t.Optional[str]

web/server/console.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from sqlmesh.core.console import TerminalConsole
1010
from sqlmesh.core.environment import EnvironmentNamingInfo
1111
from sqlmesh.core.plan.definition import EvaluatablePlan
12-
from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo, SnapshotId
12+
from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo
13+
from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers
1314
from sqlmesh.core.test import ModelTest
1415
from sqlmesh.core.test.result import ModelTextTestResult
1516
from sqlmesh.utils.date import now_timestamp
@@ -142,7 +143,7 @@ def update_snapshot_evaluation_progress(
142143
num_audits_passed: int,
143144
num_audits_failed: int,
144145
audit_only: bool = False,
145-
auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None,
146+
snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None,
146147
) -> None:
147148
if audit_only:
148149
return

0 commit comments

Comments
 (0)