From c2994c7b11d6fe564cbdcddb8dbc27fa994b8e02 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Wed, 16 Jul 2025 15:52:07 -0500 Subject: [PATCH 01/13] Print auto-restated trigger of model evaluation in debug console --- sqlmesh/core/console.py | 8 +++++ sqlmesh/core/scheduler.py | 8 ++++- sqlmesh/core/snapshot/definition.py | 51 ++++++++++++++++++++--------- tests/core/test_snapshot.py | 4 +-- web/server/console.py | 3 +- 5 files changed, 55 insertions(+), 19 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 43283ead90..33042cbec1 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -428,6 +428,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_trigger: t.Optional[SnapshotId] = None, ) -> None: """Updates the snapshot evaluation progress.""" @@ -575,6 +576,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_trigger: t.Optional[SnapshotId] = None, ) -> None: pass @@ -1056,6 +1058,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_trigger: t.Optional[SnapshotId] = None, ) -> None: """Update the snapshot evaluation progress.""" if ( @@ -3639,6 +3642,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_trigger: t.Optional[SnapshotId] = None, ) -> None: view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id] @@ -3808,9 +3812,13 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_trigger: t.Optional[SnapshotId] = None, ) -> None: message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" + if auto_restatement_trigger: + message += f" | evaluation_triggered_by={auto_restatement_trigger.name}" + if audit_only: message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index e787e57a23..cb63e3df83 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -415,6 +415,7 @@ def run_merged_intervals( selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, run_environment_statements: bool = False, audit_only: bool = False, + auto_restatement_triggers: t.Dict[SnapshotId, SnapshotId] = {}, ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: """Runs precomputed batches of missing intervals. @@ -531,6 +532,7 @@ def run_node(node: SchedulingUnit) -> None: evaluation_duration_ms, num_audits - num_audits_failed, num_audits_failed, + auto_restatement_trigger=auto_restatement_triggers.get(snapshot.snapshot_id), ) elif isinstance(node, CreateNode): self.snapshot_evaluator.create_snapshot( @@ -736,8 +738,11 @@ def _run_or_audit( for s_id, interval in (remove_intervals or {}).items(): self.snapshots[s_id].remove_interval(interval) + auto_restatement_triggers: t.Dict[SnapshotId, SnapshotId] = {} if auto_restatement_enabled: - auto_restated_intervals = apply_auto_restatements(self.snapshots, execution_time) + auto_restated_intervals, auto_restatement_triggers = apply_auto_restatements( + self.snapshots, execution_time + ) self.state_sync.add_snapshots_intervals(auto_restated_intervals) self.state_sync.update_auto_restatements( {s.name_version: s.next_auto_restatement_ts for s in self.snapshots.values()} @@ -768,6 +773,7 @@ def _run_or_audit( end=end, run_environment_statements=run_environment_statements, audit_only=audit_only, + auto_restatement_triggers=auto_restatement_triggers, ) return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index ec5a883f7f..54a2236c98 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -2180,7 +2180,7 @@ def snapshots_to_dag(snapshots: t.Collection[Snapshot]) -> DAG[SnapshotId]: def apply_auto_restatements( snapshots: t.Dict[SnapshotId, Snapshot], execution_time: TimeLike -) -> t.List[SnapshotIntervals]: +) -> t.Tuple[t.List[SnapshotIntervals], t.Dict[SnapshotId, SnapshotId]]: """Applies auto restatements to the snapshots. This operation results in the removal of intervals for snapshots that are ready to be restated based @@ -2195,6 +2195,8 @@ def apply_auto_restatements( A list of SnapshotIntervals with **new** intervals that need to be restated. """ dag = snapshots_to_dag(snapshots.values()) + snapshots_with_auto_restatements: t.List[SnapshotId] = [] + auto_restatement_triggers: t.Dict[SnapshotId, SnapshotId] = {} auto_restated_intervals_per_snapshot: t.Dict[SnapshotId, Interval] = {} for s_id in dag: if s_id not in snapshots: @@ -2218,6 +2220,23 @@ def apply_auto_restatements( ) auto_restated_intervals.append(next_auto_restated_interval) + # auto-restated snapshot is its own trigger + snapshots_with_auto_restatements.append(s_id) + auto_restatement_triggers[s_id] = s_id + else: + for parent_s_id in snapshot.parents: + # first auto-restated parent is the trigger + if parent_s_id in snapshots_with_auto_restatements: + auto_restatement_triggers[s_id] = parent_s_id + break + # if no trigger yet and parent has trigger, inherit their trigger + # - will be overwritten if a different parent is auto-restated + if ( + parent_s_id in auto_restatement_triggers + and s_id not in auto_restatement_triggers + ): + auto_restatement_triggers[s_id] = auto_restatement_triggers[parent_s_id] + if auto_restated_intervals: auto_restated_interval_start = sys.maxsize auto_restated_interval_end = -sys.maxsize @@ -2247,20 +2266,22 @@ def apply_auto_restatements( snapshot.apply_pending_restatement_intervals() snapshot.update_next_auto_restatement_ts(execution_time) - - return [ - SnapshotIntervals( - name=snapshots[s_id].name, - identifier=None, - version=snapshots[s_id].version, - dev_version=None, - intervals=[], - dev_intervals=[], - pending_restatement_intervals=[interval], - ) - for s_id, interval in auto_restated_intervals_per_snapshot.items() - if s_id in snapshots - ] + return ( + [ + SnapshotIntervals( + name=snapshots[s_id].name, + identifier=None, + version=snapshots[s_id].version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[interval], + ) + for s_id, interval in auto_restated_intervals_per_snapshot.items() + if s_id in snapshots + ], + auto_restatement_triggers, + ) def parent_snapshots_by_name( diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index bce091595c..d394f827a0 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -3102,7 +3102,7 @@ def test_apply_auto_restatements(make_snapshot): (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), ] - restated_intervals = apply_auto_restatements( + restated_intervals, _ = apply_auto_restatements( { snapshot_a.snapshot_id: snapshot_a, snapshot_b.snapshot_id: snapshot_b, @@ -3239,7 +3239,7 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot): snapshot_b.add_interval("2020-01-01", "2020-01-05") assert snapshot_a.snapshot_id in snapshot_b.parents - restated_intervals = apply_auto_restatements( + restated_intervals, _ = apply_auto_restatements( { snapshot_a.snapshot_id: snapshot_a, snapshot_b.snapshot_id: snapshot_b, diff --git a/web/server/console.py b/web/server/console.py index 2cda0af697..2d19208cec 100644 --- a/web/server/console.py +++ b/web/server/console.py @@ -9,7 +9,7 @@ from sqlmesh.core.console import TerminalConsole from sqlmesh.core.environment import EnvironmentNamingInfo from sqlmesh.core.plan.definition import EvaluatablePlan -from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo +from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo, SnapshotId from sqlmesh.core.test import ModelTest from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.utils.date import now_timestamp @@ -142,6 +142,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, + auto_restatement_trigger: t.Optional[SnapshotId] = None, ) -> None: if audit_only: return From bfa3eca5c57e3727af7331b20c4fb6ff9e6c04af Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 22 Jul 2025 15:17:01 -0500 Subject: [PATCH 02/13] List all upstream auto-restated models --- sqlmesh/core/console.py | 14 ++-- sqlmesh/core/scheduler.py | 6 +- sqlmesh/core/snapshot/definition.py | 30 ++++---- tests/core/test_snapshot.py | 105 ++++++++++++++++++++++++++++ web/server/console.py | 2 +- 5 files changed, 128 insertions(+), 29 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 33042cbec1..a404057c8a 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -428,7 +428,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_trigger: t.Optional[SnapshotId] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Updates the snapshot evaluation progress.""" @@ -576,7 +576,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_trigger: t.Optional[SnapshotId] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: pass @@ -1058,7 +1058,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_trigger: t.Optional[SnapshotId] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Update the snapshot evaluation progress.""" if ( @@ -3642,7 +3642,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_trigger: t.Optional[SnapshotId] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id] @@ -3812,12 +3812,12 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_trigger: t.Optional[SnapshotId] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" - if auto_restatement_trigger: - message += f" | evaluation_triggered_by={auto_restatement_trigger.name}" + if auto_restatement_triggers: + message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in auto_restatement_triggers)}" if audit_only: message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index cb63e3df83..a9e1081282 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -415,7 +415,7 @@ def run_merged_intervals( selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, run_environment_statements: bool = False, audit_only: bool = False, - auto_restatement_triggers: t.Dict[SnapshotId, SnapshotId] = {}, + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}, ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: """Runs precomputed batches of missing intervals. @@ -532,7 +532,7 @@ def run_node(node: SchedulingUnit) -> None: evaluation_duration_ms, num_audits - num_audits_failed, num_audits_failed, - auto_restatement_trigger=auto_restatement_triggers.get(snapshot.snapshot_id), + auto_restatement_triggers=auto_restatement_triggers.get(snapshot.snapshot_id), ) elif isinstance(node, CreateNode): self.snapshot_evaluator.create_snapshot( @@ -738,7 +738,7 @@ def _run_or_audit( for s_id, interval in (remove_intervals or {}).items(): self.snapshots[s_id].remove_interval(interval) - auto_restatement_triggers: t.Dict[SnapshotId, SnapshotId] = {} + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} if auto_restatement_enabled: auto_restated_intervals, auto_restatement_triggers = apply_auto_restatements( self.snapshots, execution_time diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 54a2236c98..e21a83910c 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -2180,7 +2180,7 @@ def snapshots_to_dag(snapshots: t.Collection[Snapshot]) -> DAG[SnapshotId]: def apply_auto_restatements( snapshots: t.Dict[SnapshotId, Snapshot], execution_time: TimeLike -) -> t.Tuple[t.List[SnapshotIntervals], t.Dict[SnapshotId, SnapshotId]]: +) -> t.Tuple[t.List[SnapshotIntervals], t.Dict[SnapshotId, t.List[SnapshotId]]]: """Applies auto restatements to the snapshots. This operation results in the removal of intervals for snapshots that are ready to be restated based @@ -2195,8 +2195,7 @@ def apply_auto_restatements( A list of SnapshotIntervals with **new** intervals that need to be restated. """ dag = snapshots_to_dag(snapshots.values()) - snapshots_with_auto_restatements: t.List[SnapshotId] = [] - auto_restatement_triggers: t.Dict[SnapshotId, SnapshotId] = {} + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} auto_restated_intervals_per_snapshot: t.Dict[SnapshotId, Interval] = {} for s_id in dag: if s_id not in snapshots: @@ -2211,6 +2210,7 @@ def apply_auto_restatements( for parent_s_id in snapshot.parents if parent_s_id in auto_restated_intervals_per_snapshot ] + upstream_triggers = [] if next_auto_restated_interval: logger.info( "Calculated the next auto restated interval (%s, %s) for snapshot %s", @@ -2221,21 +2221,15 @@ def apply_auto_restatements( auto_restated_intervals.append(next_auto_restated_interval) # auto-restated snapshot is its own trigger - snapshots_with_auto_restatements.append(s_id) - auto_restatement_triggers[s_id] = s_id - else: - for parent_s_id in snapshot.parents: - # first auto-restated parent is the trigger - if parent_s_id in snapshots_with_auto_restatements: - auto_restatement_triggers[s_id] = parent_s_id - break - # if no trigger yet and parent has trigger, inherit their trigger - # - will be overwritten if a different parent is auto-restated - if ( - parent_s_id in auto_restatement_triggers - and s_id not in auto_restatement_triggers - ): - auto_restatement_triggers[s_id] = auto_restatement_triggers[parent_s_id] + upstream_triggers = [s_id] + + for parent_s_id in snapshot.parents: + if parent_s_id in auto_restatement_triggers: + upstream_triggers.extend(auto_restatement_triggers[parent_s_id]) + + # remove duplicate triggers + if upstream_triggers: + auto_restatement_triggers[s_id] = list(dict.fromkeys(upstream_triggers)) if auto_restated_intervals: auto_restated_interval_start = sys.maxsize diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index d394f827a0..2cd43e1b48 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -3279,6 +3279,111 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot): ] +def test_auto_restatement_triggers(make_snapshot): + model_a = SqlModel( + name="test_model_a", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT 1 as ds"), + ) + snapshot_a = make_snapshot(model_a, version="1") + snapshot_a.add_interval("2020-01-01", "2020-01-05") + snapshot_a.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + model_b = SqlModel( + name="test_model_b", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT ds FROM test_model_a"), + ) + snapshot_b = make_snapshot(model_b, nodes={model_a.fqn: model_a}, version="1") + snapshot_b.add_interval("2020-01-01", "2020-01-05") + + model_c = SqlModel( + name="test_model_c", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT ds FROM test_model_a"), + ) + snapshot_c = make_snapshot(model_c, nodes={model_a.fqn: model_a}, version="1") + snapshot_c.add_interval("2020-01-01", "2020-01-05") + snapshot_c.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + model_d = SqlModel( + name="test_model_d", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT 1 as ds"), + ) + snapshot_d = make_snapshot(model_d, version="1") + snapshot_d.add_interval("2020-01-01", "2020-01-05") + snapshot_d.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + model_e = SqlModel( + name="test_model_e", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + ), + start="2020-01-01", + cron="@daily", + query=parse_one( + "SELECT ds from test_model_b UNION ALL SELECT ds from test_model_c UNION ALL SELECT ds from test_model_d" + ), + ) + snapshot_e = make_snapshot( + model_e, + nodes={ + model_a.fqn: model_a, + model_b.fqn: model_b, + model_c.fqn: model_c, + model_d.fqn: model_d, + }, + version="1", + ) + snapshot_e.add_interval("2020-01-01", "2020-01-05") + + _, auto_restatement_triggers = apply_auto_restatements( + { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_c.snapshot_id: snapshot_c, + snapshot_d.snapshot_id: snapshot_d, + snapshot_e.snapshot_id: snapshot_e, + }, + "2020-01-06 10:01:00", + ) + + assert auto_restatement_triggers == { + snapshot_a.snapshot_id: [snapshot_a.snapshot_id], + snapshot_d.snapshot_id: [snapshot_d.snapshot_id], + snapshot_b.snapshot_id: [snapshot_a.snapshot_id], + snapshot_c.snapshot_id: [snapshot_c.snapshot_id, snapshot_a.snapshot_id], + snapshot_e.snapshot_id: [ + snapshot_d.snapshot_id, + snapshot_c.snapshot_id, + snapshot_a.snapshot_id, + ], + } + + def test_render_signal(make_snapshot, mocker): @signal() def check_types(batch, env: str, sql: list[SQL], table: exp.Table, default: int = 0): diff --git a/web/server/console.py b/web/server/console.py index 2d19208cec..902a85418c 100644 --- a/web/server/console.py +++ b/web/server/console.py @@ -142,7 +142,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_trigger: t.Optional[SnapshotId] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: if audit_only: return From 30a6b73f118b8caeb38b042d22056a1ee58db9a8 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 22 Jul 2025 15:38:43 -0500 Subject: [PATCH 03/13] Make test deterministic --- tests/core/test_snapshot.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 2cd43e1b48..8a8a349892 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -3371,17 +3371,18 @@ def test_auto_restatement_triggers(make_snapshot): "2020-01-06 10:01:00", ) - assert auto_restatement_triggers == { - snapshot_a.snapshot_id: [snapshot_a.snapshot_id], - snapshot_d.snapshot_id: [snapshot_d.snapshot_id], - snapshot_b.snapshot_id: [snapshot_a.snapshot_id], - snapshot_c.snapshot_id: [snapshot_c.snapshot_id, snapshot_a.snapshot_id], - snapshot_e.snapshot_id: [ - snapshot_d.snapshot_id, - snapshot_c.snapshot_id, - snapshot_a.snapshot_id, - ], - } + assert auto_restatement_triggers[snapshot_a.snapshot_id] == [snapshot_a.snapshot_id] + assert auto_restatement_triggers[snapshot_d.snapshot_id] == [snapshot_d.snapshot_id] + assert auto_restatement_triggers[snapshot_b.snapshot_id] == [snapshot_a.snapshot_id] + assert auto_restatement_triggers[snapshot_c.snapshot_id] == [ + snapshot_c.snapshot_id, + snapshot_a.snapshot_id, + ] + assert sorted(auto_restatement_triggers[snapshot_e.snapshot_id]) == [ + snapshot_a.snapshot_id, + snapshot_c.snapshot_id, + snapshot_d.snapshot_id, + ] def test_render_signal(make_snapshot, mocker): From bbed5acc4744a8c4690e272e3b1c7f9200512d21 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Thu, 24 Jul 2025 11:34:31 -0500 Subject: [PATCH 04/13] Add triggers class and selected snapshot triggers --- sqlmesh/core/console.py | 24 ++++++++++----- sqlmesh/core/context.py | 16 +++++----- sqlmesh/core/scheduler.py | 46 ++++++++++++++++++++++++++++- sqlmesh/core/snapshot/definition.py | 8 +++++ web/server/console.py | 5 ++-- 5 files changed, 80 insertions(+), 19 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index a404057c8a..22a0a35a56 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -37,7 +37,12 @@ SnapshotId, SnapshotInfoLike, ) -from sqlmesh.core.snapshot.definition import Interval, Intervals, SnapshotTableInfo +from sqlmesh.core.snapshot.definition import ( + Interval, + Intervals, + SnapshotTableInfo, + SnapshotEvaluationTriggers, +) from sqlmesh.core.test import ModelTest from sqlmesh.utils import rich as srich from sqlmesh.utils import Verbosity @@ -428,7 +433,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, + snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, ) -> None: """Updates the snapshot evaluation progress.""" @@ -576,7 +581,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, + snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, ) -> None: pass @@ -1058,7 +1063,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, + snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, ) -> None: """Update the snapshot evaluation progress.""" if ( @@ -3642,7 +3647,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, + snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, ) -> None: view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id] @@ -3812,12 +3817,15 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, + snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, ) -> None: message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" - if auto_restatement_triggers: - message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in auto_restatement_triggers)}" + if snapshot_evaluation_triggers: + if snapshot_evaluation_triggers.auto_restatement_triggers: + message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}" + if snapshot_evaluation_triggers.select_snapshot_triggers: + message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}" if audit_only: message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 9022f3f069..5aee6e5135 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2307,11 +2307,9 @@ def check_intervals( } if select_models: - selected: t.Collection[str] = self._select_models_for_run( - select_models, True, snapshots.values() - ) + selected, _ = self._select_models_for_run(select_models, True, snapshots.values()) else: - selected = snapshots.keys() + selected = t.cast(t.Set[str], snapshots.keys()) results = {} execution_context = self.execution_context(snapshots=snapshots) @@ -2461,8 +2459,9 @@ def _run( scheduler = self.scheduler(environment=environment) snapshots = scheduler.snapshots + select_models_auto_upstream = None if select_models is not None: - select_models = self._select_models_for_run( + select_models, select_models_auto_upstream = self._select_models_for_run( select_models, no_auto_upstream, snapshots.values() ) @@ -2474,6 +2473,7 @@ def _run( ignore_cron=ignore_cron, circuit_breaker=circuit_breaker, selected_snapshots=select_models, + selected_snapshots_auto_upstream=select_models_auto_upstream, auto_restatement_enabled=environment.lower() == c.PROD, run_environment_statements=True, ) @@ -2889,7 +2889,7 @@ def _select_models_for_run( select_models: t.Collection[str], no_auto_upstream: bool, snapshots: t.Collection[Snapshot], - ) -> t.Set[str]: + ) -> t.Tuple[t.Set[str], t.Set[str]]: models: UniqueKeyDict[str, Model] = UniqueKeyDict( "models", **{s.name: s.model for s in snapshots if s.is_model} ) @@ -2899,8 +2899,8 @@ def _select_models_for_run( model_selector = self._new_selector(models=models, dag=dag) result = set(model_selector.expand_model_selections(select_models)) if not no_auto_upstream: - result = set(dag.subdag(*result)) - return result + result_with_upstream = set(dag.subdag(*result)) + return result, result_with_upstream - result @cached_property def _project_type(self) -> str: diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index a9e1081282..6599e2efe8 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -31,6 +31,7 @@ from sqlmesh.core.snapshot.definition import check_ready_intervals from sqlmesh.core.snapshot.definition import ( Interval, + SnapshotEvaluationTriggers, expand_range, parent_snapshots_by_name, ) @@ -262,6 +263,7 @@ def run( ignore_cron: bool = False, end_bounded: bool = False, selected_snapshots: t.Optional[t.Set[str]] = None, + selected_snapshots_auto_upstream: t.Optional[t.Set[str]] = None, circuit_breaker: t.Optional[t.Callable[[], bool]] = None, deployability_index: t.Optional[DeployabilityIndex] = None, auto_restatement_enabled: bool = False, @@ -278,6 +280,7 @@ def run( ignore_cron=ignore_cron, end_bounded=end_bounded, selected_snapshots=selected_snapshots, + selected_snapshots_auto_upstream=selected_snapshots_auto_upstream, circuit_breaker=circuit_breaker, deployability_index=deployability_index, auto_restatement_enabled=auto_restatement_enabled, @@ -532,7 +535,9 @@ def run_node(node: SchedulingUnit) -> None: evaluation_duration_ms, num_audits - num_audits_failed, num_audits_failed, - auto_restatement_triggers=auto_restatement_triggers.get(snapshot.snapshot_id), + snapshot_evaluation_triggers=snapshot_evaluation_triggers.get( + snapshot.snapshot_id + ), ) elif isinstance(node, CreateNode): self.snapshot_evaluator.create_snapshot( @@ -685,6 +690,7 @@ def _run_or_audit( ignore_cron: bool = False, end_bounded: bool = False, selected_snapshots: t.Optional[t.Set[str]] = None, + selected_snapshots_auto_upstream: t.Optional[t.Set[str]] = None, circuit_breaker: t.Optional[t.Callable[[], bool]] = None, deployability_index: t.Optional[DeployabilityIndex] = None, auto_restatement_enabled: bool = False, @@ -708,6 +714,7 @@ def _run_or_audit( end_bounded: If set to true, the evaluated intervals will be bounded by the target end date, disregarding lookback, allow_partials, and other attributes that could cause the intervals to exceed the target end date. selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run. + selected_snapshots_auto_upstream: The set of selected_snapshots that were automatically added because they're upstream of a selected snapshot. circuit_breaker: An optional handler which checks if the run should be aborted. deployability_index: Determines snapshots that are deployable in the context of this render. auto_restatement_enabled: Whether to enable auto restatements. @@ -763,6 +770,42 @@ def _run_or_audit( if not merged_intervals: return CompletionStatus.NOTHING_TO_DO + merged_intervals_snapshots = { + snapshot.snapshot_id: snapshot for snapshot in merged_intervals.keys() + } + select_snapshot_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} + if selected_snapshots and selected_snapshots_auto_upstream: + # actually selected snapshots are their own triggers + selected_snapshots_no_auto_upstream = ( + selected_snapshots - selected_snapshots_auto_upstream + ) + select_snapshot_triggers = { + s_id: [s_id] + for s_id in [ + snapshot_id + for snapshot_id in merged_intervals_snapshots + if snapshot_id.name in selected_snapshots_no_auto_upstream + ] + } + + # trace upstream by reversing dag of all snapshots to evaluate + reversed_intervals_dag = snapshots_to_dag(merged_intervals_snapshots.values()).reversed + for s_id in reversed_intervals_dag: + if s_id not in select_snapshot_triggers: + triggers = [] + for parent_s_id in merged_intervals_snapshots[s_id].parents: + triggers.extend(select_snapshot_triggers[parent_s_id]) + select_snapshot_triggers[s_id] = list(dict.fromkeys(triggers)) + + all_snapshot_triggers: t.Dict[SnapshotId, SnapshotEvaluationTriggers] = { + s_id: SnapshotEvaluationTriggers( + ignore_cron=ignore_cron, + auto_restatement_triggers=auto_restatement_triggers.get(s_id, []), + select_snapshot_triggers=select_snapshot_triggers.get(s_id, []), + ) + for s_id in merged_intervals_snapshots + if ignore_cron or s_id in auto_restatement_triggers or s_id in select_snapshot_triggers + } errors, _ = self.run_merged_intervals( merged_intervals=merged_intervals, deployability_index=deployability_index, @@ -773,6 +816,7 @@ def _run_or_audit( end=end, run_environment_statements=run_environment_statements, audit_only=audit_only, + restatements=remove_intervals, auto_restatement_triggers=auto_restatement_triggers, ) diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index e21a83910c..5ef3f2c4c5 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -327,6 +327,14 @@ def table_name_for_environment( return table +class SnapshotEvaluationTriggers(PydanticModel): + ignore_cron: bool + auto_restatement_triggers: t.List[SnapshotId] = [] + select_snapshot_triggers: t.List[SnapshotId] = [] + directly_modified_triggers: t.List[SnapshotId] = [] + manual_restatement_triggers: t.List[SnapshotId] = [] + + class SnapshotInfoMixin(ModelKindMixin): name: str dev_version_: t.Optional[str] diff --git a/web/server/console.py b/web/server/console.py index 902a85418c..5af93864f6 100644 --- a/web/server/console.py +++ b/web/server/console.py @@ -9,7 +9,8 @@ from sqlmesh.core.console import TerminalConsole from sqlmesh.core.environment import EnvironmentNamingInfo from sqlmesh.core.plan.definition import EvaluatablePlan -from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo, SnapshotId +from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo +from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers from sqlmesh.core.test import ModelTest from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.utils.date import now_timestamp @@ -142,7 +143,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, + snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, ) -> None: if audit_only: return From c4952c6f955c1f4437ae50415caec323aaf0ed8a Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Thu, 31 Jul 2025 18:35:37 -0500 Subject: [PATCH 05/13] Collect selected snapshot triggers --- sqlmesh/core/console.py | 14 +++- sqlmesh/core/context.py | 7 +- sqlmesh/core/plan/stages.py | 3 +- sqlmesh/core/scheduler.py | 46 ++++++------ sqlmesh/core/snapshot/definition.py | 5 +- tests/core/test_integration.py | 107 +++++++++++++++++++++++----- tests/core/test_scheduler.py | 7 +- web/server/api/endpoints/plan.py | 2 +- 8 files changed, 140 insertions(+), 51 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 22a0a35a56..04c96082f4 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -3819,7 +3819,17 @@ def update_snapshot_evaluation_progress( audit_only: bool = False, snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, ) -> None: - message = f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" + message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" + + if snapshot_evaluation_triggers: + if snapshot_evaluation_triggers.ignore_cron_flag is not None: + message += f" | ignore_cron_flag={snapshot_evaluation_triggers.ignore_cron_flag}" + if snapshot_evaluation_triggers.cron_ready is not None: + message += f" | cron_ready={snapshot_evaluation_triggers.cron_ready}" + if snapshot_evaluation_triggers.auto_restatement_triggers: + message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}" + if snapshot_evaluation_triggers.select_snapshot_triggers: + message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}" if snapshot_evaluation_triggers: if snapshot_evaluation_triggers.auto_restatement_triggers: @@ -3828,7 +3838,7 @@ def update_snapshot_evaluation_progress( message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}" if audit_only: - message = f"Auditing {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" + message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" self._write(message) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 5aee6e5135..a4fab138fb 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2898,9 +2898,10 @@ def _select_models_for_run( dag.add(fqn, model.depends_on) model_selector = self._new_selector(models=models, dag=dag) result = set(model_selector.expand_model_selections(select_models)) - if not no_auto_upstream: - result_with_upstream = set(dag.subdag(*result)) - return result, result_with_upstream - result + if no_auto_upstream: + return result, set() + result_with_upstream = set(dag.subdag(*result)) + return result_with_upstream, result_with_upstream - result @cached_property def _project_type(self) -> str: diff --git a/sqlmesh/core/plan/stages.py b/sqlmesh/core/plan/stages.py index 82223dd807..c1eb3f9927 100644 --- a/sqlmesh/core/plan/stages.py +++ b/sqlmesh/core/plan/stages.py @@ -553,7 +553,7 @@ def _missing_intervals( snapshots_by_name: t.Dict[str, Snapshot], deployability_index: DeployabilityIndex, ) -> SnapshotToIntervals: - return merged_missing_intervals( + missing_intervals, _ = merged_missing_intervals( snapshots=snapshots_by_name.values(), start=plan.start, end=plan.end, @@ -568,6 +568,7 @@ def _missing_intervals( start_override_per_model=plan.start_override_per_model, end_override_per_model=plan.end_override_per_model, ) + return missing_intervals def _get_audit_only_snapshots( self, new_snapshots: t.Dict[SnapshotId, Snapshot] diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 6599e2efe8..aca4fec167 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -147,7 +147,7 @@ def merged_missing_intervals( ignore_cron: bool = False, end_bounded: bool = False, selected_snapshots: t.Optional[t.Set[str]] = None, - ) -> SnapshotToIntervals: + ) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]: """Find the largest contiguous date interval parameters based only on what is missing. For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found, @@ -167,8 +167,11 @@ def merged_missing_intervals( end_bounded: If set to true, the returned intervals will be bounded by the target end date, disregarding lookback, allow_partials, and other attributes that could cause the intervals to exceed the target end date. selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run. + + Returns: + A tuple containing a dict containing all snapshots needing to be run with their associated interval params and a list of snapshots that are ready to run based on their naive cron schedule (ignoring plan/run context and other attributes). """ - snapshots_to_intervals = merged_missing_intervals( + snapshots_to_intervals, snapshots_naive_cron_ready = merged_missing_intervals( snapshots=self.snapshot_per_version.values(), start=start, end=end, @@ -186,7 +189,7 @@ def merged_missing_intervals( snapshots_to_intervals = { s: i for s, i in snapshots_to_intervals.items() if s.name in selected_snapshots } - return snapshots_to_intervals + return snapshots_to_intervals, snapshots_naive_cron_ready def evaluate( self, @@ -755,7 +758,7 @@ def _run_or_audit( {s.name_version: s.next_auto_restatement_ts for s in self.snapshots.values()} ) - merged_intervals = self.merged_missing_intervals( + merged_intervals, snapshots_naive_cron_ready = self.merged_missing_intervals( start, end, execution_time, @@ -770,9 +773,7 @@ def _run_or_audit( if not merged_intervals: return CompletionStatus.NOTHING_TO_DO - merged_intervals_snapshots = { - snapshot.snapshot_id: snapshot for snapshot in merged_intervals.keys() - } + merged_intervals_snapshots = {snapshot.snapshot_id for snapshot in merged_intervals} select_snapshot_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} if selected_snapshots and selected_snapshots_auto_upstream: # actually selected snapshots are their own triggers @@ -788,24 +789,25 @@ def _run_or_audit( ] } - # trace upstream by reversing dag of all snapshots to evaluate - reversed_intervals_dag = snapshots_to_dag(merged_intervals_snapshots.values()).reversed - for s_id in reversed_intervals_dag: - if s_id not in select_snapshot_triggers: - triggers = [] - for parent_s_id in merged_intervals_snapshots[s_id].parents: - triggers.extend(select_snapshot_triggers[parent_s_id]) + # trace upstream by walking downstream on reversed dag + reversed_dag = snapshots_to_dag(self.snapshots.values()).reversed + for s_id in reversed_dag: + if s_id in merged_intervals_snapshots: + triggers = select_snapshot_triggers.get(s_id, []) + for parent_s_id in reversed_dag.graph.get(s_id, set()): + triggers.extend(select_snapshot_triggers.get(parent_s_id, [])) select_snapshot_triggers[s_id] = list(dict.fromkeys(triggers)) all_snapshot_triggers: t.Dict[SnapshotId, SnapshotEvaluationTriggers] = { s_id: SnapshotEvaluationTriggers( - ignore_cron=ignore_cron, + ignore_cron_flag=ignore_cron, + cron_ready=s_id in snapshots_naive_cron_ready, auto_restatement_triggers=auto_restatement_triggers.get(s_id, []), select_snapshot_triggers=select_snapshot_triggers.get(s_id, []), ) for s_id in merged_intervals_snapshots - if ignore_cron or s_id in auto_restatement_triggers or s_id in select_snapshot_triggers } + errors, _ = self.run_merged_intervals( merged_intervals=merged_intervals, deployability_index=deployability_index, @@ -967,7 +969,7 @@ def merged_missing_intervals( end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, ignore_cron: bool = False, end_bounded: bool = False, -) -> SnapshotToIntervals: +) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]: """Find the largest contiguous date interval parameters based only on what is missing. For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found, @@ -1017,7 +1019,7 @@ def compute_interval_params( end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, ignore_cron: bool = False, end_bounded: bool = False, -) -> SnapshotToIntervals: +) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]: """Find the largest contiguous date interval parameters based only on what is missing. For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found, @@ -1039,7 +1041,7 @@ def compute_interval_params( allow_partials, and other attributes that could cause the intervals to exceed the target end date. Returns: - A dict containing all snapshots needing to be run with their associated interval params. + A tuple containing a dict containing all snapshots needing to be run with their associated interval params and a list of snapshots that are ready to run based on their naive cron schedule (ignoring plan/run context and other attributes). """ snapshot_merged_intervals = {} @@ -1067,7 +1069,11 @@ def compute_interval_params( contiguous_batch.append((next_batch[0][0], next_batch[-1][-1])) snapshot_merged_intervals[snapshot] = contiguous_batch - return snapshot_merged_intervals + snapshots_naive_cron_ready = [ + snap.snapshot_id for snap in missing_intervals(snapshots, execution_time=execution_time) + ] + + return snapshot_merged_intervals, snapshots_naive_cron_ready def interval_diff( diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 5ef3f2c4c5..37b9acc275 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -328,11 +328,10 @@ def table_name_for_environment( class SnapshotEvaluationTriggers(PydanticModel): - ignore_cron: bool + ignore_cron_flag: t.Optional[bool] = None + cron_ready: t.Optional[bool] = None auto_restatement_triggers: t.List[SnapshotId] = [] select_snapshot_triggers: t.List[SnapshotId] = [] - directly_modified_triggers: t.List[SnapshotId] = [] - manual_restatement_triggers: t.List[SnapshotId] = [] class SnapshotInfoMixin(ModelKindMixin): diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index fc129424f4..3577de5f9d 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -1862,6 +1862,77 @@ def test_select_unchanged_model_for_backfill(init_and_plan_context: t.Callable): assert {o.name for o in schema_objects} == {"waiter_revenue_by_day", "top_waiters"} +@time_machine.travel("2023-01-08 00:00:00 UTC") +def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixture): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # add auto restatement to orders + model = context.get_model("sushi.orders") + kind = { + **model.kind.dict(), + "auto_restatement_cron": "@hourly", + } + kwargs = { + **model.dict(), + "kind": kind, + } + context.upsert_model(PythonModel.parse_obj(kwargs)) + plan = context.plan_builder(skip_tests=True).build() + context.apply(plan) + + # Mock run_merged_intervals to capture triggers arg + scheduler = context.scheduler() + run_merged_intervals_mock = mocker.patch.object( + scheduler, "run_merged_intervals", return_value=([], []) + ) + + # User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream + selected_models = {"top_waiters", "waiter_revenue_by_day"} + selected_models_auto_upstream = {"order_items", "orders", "items"} + selected_snapshots = { + f'"memory"."sushi"."{model}"' for model in selected_models | selected_models_auto_upstream + } + selected_snapshots_auto_upstream = selected_snapshots - { + f'"memory"."sushi"."{model}"' for model in selected_models + } + + with time_machine.travel("2023-01-09 00:00:01 UTC"): + scheduler.run( + environment=c.PROD, + selected_snapshots=selected_snapshots, + selected_snapshots_auto_upstream=selected_snapshots_auto_upstream, + start="2023-01-01", + auto_restatement_enabled=True, + ) + + assert run_merged_intervals_mock.called + + actual_triggers = run_merged_intervals_mock.call_args.kwargs["snapshot_evaluation_triggers"] + + # validate ignore_cron not passed and all model crons ready + assert all( + not trigger.ignore_cron_flag and trigger.cron_ready for trigger in actual_triggers.values() + ) + + for id, trigger in actual_triggers.items(): + # top_waiters is its own trigger, waiter_revenue_by_day is upstream of it, everyone else is upstream of both + select_triggers = [t.name for t in trigger.select_snapshot_triggers] + assert ( + select_triggers == ['"memory"."sushi"."top_waiters"'] + if id.name == '"memory"."sushi"."top_waiters"' + else ['"memory"."sushi"."waiter_revenue_by_day"', '"memory"."sushi"."top_waiters"'] + ) + + # everyone other than items is downstream of orders + auto_restatement_triggers = [t.name for t in trigger.auto_restatement_triggers] + assert ( + auto_restatement_triggers == [] + if id.name == '"memory"."sushi"."items"' + else ['"memory"."sushi"."orders"'] + ) + + @time_machine.travel("2023-01-08 15:00:00 UTC") def test_max_interval_end_per_model_not_applied_when_end_is_provided( init_and_plan_context: t.Callable, @@ -7705,7 +7776,7 @@ def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -7749,7 +7820,7 @@ def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 2 as id, 3 as new_column, @@ -7803,9 +7874,9 @@ def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): start '2023-01-01', cron '@daily' ); - + SELECT - *, + *, 2 as id, CAST(4 AS STRING) as new_column, @start_ds as ds @@ -7844,8 +7915,8 @@ def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): start '2023-01-01', cron '@daily' ); - - SELECT + + SELECT *, 2 as id, CAST(5 AS STRING) as new_column, @@ -7905,7 +7976,7 @@ def test_incremental_by_unique_key_model_ignore_destructive_change(tmp_path: Pat cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -7949,7 +8020,7 @@ def test_incremental_by_unique_key_model_ignore_destructive_change(tmp_path: Pat cron '@daily' ); - SELECT + SELECT *, 2 as id, 3 as new_column, @@ -8016,7 +8087,7 @@ def test_incremental_unmanaged_model_ignore_destructive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8059,7 +8130,7 @@ def test_incremental_unmanaged_model_ignore_destructive_change(tmp_path: Path): ); SELECT - *, + *, 2 as id, 3 as new_column, @start_ds as ds @@ -8240,7 +8311,7 @@ def test_scd_type_2_by_column_ignore_destructive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8285,7 +8356,7 @@ def test_scd_type_2_by_column_ignore_destructive_change(tmp_path: Path): ); SELECT - *, + *, 1 as id, 3 as new_column, @start_ds as ds @@ -8352,7 +8423,7 @@ def test_incremental_partition_ignore_destructive_change(tmp_path: Path): cron '@daily' ); - SELECT + SELECT *, 1 as id, 'test_name' as name, @@ -8396,7 +8467,7 @@ def test_incremental_partition_ignore_destructive_change(tmp_path: Path): ); SELECT - *, + *, 1 as id, 3 as new_column, @start_ds as ds @@ -8467,7 +8538,7 @@ def test_incremental_by_time_model_ignore_destructive_change_unit_test(tmp_path: cron '@daily' ); - SELECT + SELECT id, name, ds @@ -8479,7 +8550,7 @@ def test_incremental_by_time_model_ignore_destructive_change_unit_test(tmp_path: (models_dir / "test_model.sql").write_text(initial_model) initial_test = f""" - + test_test_model: model: test_model inputs: @@ -8534,8 +8605,8 @@ def test_incremental_by_time_model_ignore_destructive_change_unit_test(tmp_path: start '2023-01-01', cron '@daily' ); - - SELECT + + SELECT id, new_column, ds diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index b74aa3480e..e51c394407 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -79,7 +79,7 @@ def _get_batched_missing_intervals( end: TimeLike, execution_time: t.Optional[TimeLike] = None, ) -> SnapshotToIntervals: - merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time) + merged_intervals, _ = scheduler.merged_missing_intervals(start, end, execution_time) return scheduler.batch_intervals(merged_intervals, mocker.Mock(), mocker.Mock()) return _get_batched_missing_intervals @@ -107,9 +107,10 @@ def test_interval_params_missing(scheduler: Scheduler, sushi_context_fixed_date: start_ds = "2022-01-01" end_ds = "2022-03-01" - assert compute_interval_params( + interval_params, _ = compute_interval_params( sushi_context_fixed_date.snapshots.values(), start=start_ds, end=end_ds - )[waiters] == [ + ) + assert interval_params[waiters] == [ (to_timestamp(start_ds), to_timestamp("2022-03-02")), ] diff --git a/web/server/api/endpoints/plan.py b/web/server/api/endpoints/plan.py index ba4feb34d3..39b0e810c6 100644 --- a/web/server/api/endpoints/plan.py +++ b/web/server/api/endpoints/plan.py @@ -132,7 +132,7 @@ def _get_plan_changes(context: Context, plan: Plan) -> models.PlanChanges: def _get_plan_backfills(context: Context, plan: Plan) -> t.Dict[str, t.Any]: """Get plan backfills""" - merged_intervals = context.scheduler().merged_missing_intervals() + merged_intervals, _ = context.scheduler().merged_missing_intervals() batches = context.scheduler().batch_intervals(merged_intervals, None, EnvironmentNamingInfo()) tasks = {snapshot.name: len(intervals) for snapshot, intervals in batches.items()} snapshots = plan.context_diff.snapshots From 46dbc3ec0319cb775d03a6fb0182512db3974afa Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Wed, 6 Aug 2025 11:44:12 -0500 Subject: [PATCH 06/13] Add directly modified and restatement triggers --- sqlmesh/core/console.py | 10 ++- sqlmesh/core/context.py | 2 +- sqlmesh/core/plan/builder.py | 19 +++-- sqlmesh/core/plan/definition.py | 3 + sqlmesh/core/plan/evaluator.py | 22 ++++++ sqlmesh/core/snapshot/definition.py | 2 + tests/core/test_integration.py | 103 ++++++++++++++++++++++++---- 7 files changed, 136 insertions(+), 25 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 04c96082f4..b4d962f4f8 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -3830,12 +3830,10 @@ def update_snapshot_evaluation_progress( message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}" if snapshot_evaluation_triggers.select_snapshot_triggers: message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}" - - if snapshot_evaluation_triggers: - if snapshot_evaluation_triggers.auto_restatement_triggers: - message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}" - if snapshot_evaluation_triggers.select_snapshot_triggers: - message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}" + if snapshot_evaluation_triggers.directly_modified_triggers: + message += f" | directly_modified_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.directly_modified_triggers)}" + if snapshot_evaluation_triggers.restatement_triggers: + message += f" | restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.restatement_triggers)}" if audit_only: message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index a4fab138fb..8a12136f83 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2309,7 +2309,7 @@ def check_intervals( if select_models: selected, _ = self._select_models_for_run(select_models, True, snapshots.values()) else: - selected = t.cast(t.Set[str], snapshots.keys()) + selected = set(snapshots.keys()) results = {} execution_context = self.execution_context(snapshots=snapshots) diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index 9451f9eb53..43f77e2323 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -293,7 +293,7 @@ def build(self) -> Plan: else DeployabilityIndex.all_deployable() ) - restatements = self._build_restatements( + restatements, restatement_triggers = self._build_restatements( dag, earliest_interval_start(self._context_diff.snapshots.values(), self.execution_time), ) @@ -330,6 +330,7 @@ def build(self) -> Plan: indirectly_modified=indirectly_modified, deployability_index=deployability_index, restatements=restatements, + restatement_triggers=restatement_triggers, start_override_per_model=self._start_override_per_model, end_override_per_model=end_override_per_model, selected_models_to_backfill=self._backfill_models, @@ -352,14 +353,14 @@ def _build_dag(self) -> DAG[SnapshotId]: def _build_restatements( self, dag: DAG[SnapshotId], earliest_interval_start: TimeLike - ) -> t.Dict[SnapshotId, Interval]: + ) -> t.Tuple[t.Dict[SnapshotId, Interval], t.Dict[SnapshotId, t.List[SnapshotId]]]: restate_models = self._restate_models if restate_models == set(): # This is a warning but we print this as error since the Console is lacking API for warnings. self._console.log_error( "Provided restated models do not match any models. No models will be included in plan." ) - return {} + return {}, {} restatements: t.Dict[SnapshotId, Interval] = {} forward_only_preview_needed = self._forward_only_preview_needed @@ -383,7 +384,7 @@ def _build_restatements( is_preview = True if not restate_models: - return {} + return {}, {} start = self._start or earliest_interval_start end = self._end or now() @@ -393,6 +394,7 @@ def _build_restatements( if model_fqn not in self._model_fqn_to_snapshot: raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.") + restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} # Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands it's # restatement range that it's downstream dependencies all expand their restatement ranges as well. for s_id in dag: @@ -428,6 +430,13 @@ def _build_restatements( logger.info("Skipping restatement for model '%s'", snapshot.name) continue + if snapshot.name in restate_models: + restatement_triggers[s_id] = [s_id] + if restating_parents: + restatement_triggers[s_id] = restatement_triggers.get(s_id, []) + [ + s.snapshot_id for s in restating_parents + ] + possible_intervals = { restatements[p.snapshot_id] for p in restating_parents if p.is_incremental } @@ -456,7 +465,7 @@ def _build_restatements( restatements[s_id] = (snapshot_start, snapshot_end) - return restatements + return restatements, restatement_triggers def _build_directly_and_indirectly_modified( self, dag: DAG[SnapshotId] diff --git a/sqlmesh/core/plan/definition.py b/sqlmesh/core/plan/definition.py index 300ac62faf..62b8544939 100644 --- a/sqlmesh/core/plan/definition.py +++ b/sqlmesh/core/plan/definition.py @@ -58,6 +58,7 @@ class Plan(PydanticModel, frozen=True): deployability_index: DeployabilityIndex restatements: t.Dict[SnapshotId, Interval] + restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} start_override_per_model: t.Optional[t.Dict[str, datetime]] end_override_per_model: t.Optional[t.Dict[str, datetime]] @@ -256,6 +257,7 @@ def to_evaluatable(self) -> EvaluatablePlan: skip_backfill=self.skip_backfill, empty_backfill=self.empty_backfill, restatements={s.name: i for s, i in self.restatements.items()}, + restatement_triggers=self.restatement_triggers, is_dev=self.is_dev, allow_destructive_models=self.allow_destructive_models, forward_only=self.forward_only, @@ -298,6 +300,7 @@ class EvaluatablePlan(PydanticModel): skip_backfill: bool empty_backfill: bool restatements: t.Dict[str, Interval] + restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} is_dev: bool allow_destructive_models: t.Set[str] forward_only: bool diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 46142b7eeb..ba237f9be6 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -37,6 +37,7 @@ SnapshotCreationFailedError, SnapshotNameVersion, ) +from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers from sqlmesh.utils import to_snake_case from sqlmesh.core.state_sync import StateSync from sqlmesh.utils import CorrelationId @@ -244,6 +245,27 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla self.console.log_success("SKIP: No model batches to execute") return + directly_modified_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} + for parent, children in plan.indirectly_modified_snapshots.items(): + parent_id = stage.all_snapshots[parent].snapshot_id + directly_modified_triggers[parent_id] = directly_modified_triggers.get( + parent_id, [] + ) + [parent_id] + for child in children: + directly_modified_triggers[child] = directly_modified_triggers.get(child, []) + [ + parent_id + ] + directly_modified_triggers = { + k: list(dict.fromkeys(v)) for k, v in directly_modified_triggers.items() + } + snapshot_evaluation_triggers = { + s_id: SnapshotEvaluationTriggers( + directly_modified_triggers=directly_modified_triggers.get(s_id, []), + restatement_triggers=plan.restatement_triggers.get(s_id, []), + ) + for s_id in [s.snapshot_id for s in stage.all_snapshots.values()] + } + scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator) errors, _ = scheduler.run_merged_intervals( merged_intervals=stage.snapshot_to_intervals, diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 37b9acc275..4b28ec5d6e 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -332,6 +332,8 @@ class SnapshotEvaluationTriggers(PydanticModel): cron_ready: t.Optional[bool] = None auto_restatement_triggers: t.List[SnapshotId] = [] select_snapshot_triggers: t.List[SnapshotId] = [] + directly_modified_triggers: t.List[SnapshotId] = [] + restatement_triggers: t.List[SnapshotId] = [] class SnapshotInfoMixin(ModelKindMixin): diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index 3577de5f9d..d1b49de216 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -26,6 +26,7 @@ from sqlmesh import CustomMaterialization +import sqlmesh from sqlmesh.cli.project_init import init_example_project from sqlmesh.core import constants as c from sqlmesh.core import dialect as d @@ -1867,26 +1868,97 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt context, plan = init_and_plan_context("examples/sushi") context.apply(plan) + # modify 3 models + # - 2 breaking changes for testing plan directly modified triggers + # - 1 adding an auto-restatement for subsequent `run` test + marketing = context.get_model("sushi.marketing") + marketing_kwargs = { + **marketing.dict(), + "query": d.parse_one( + f"{marketing.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb" + ), + } + context.upsert_model(SqlModel.parse_obj(marketing_kwargs)) + + customers = context.get_model("sushi.customers") + customers_kwargs = { + **customers.dict(), + "query": d.parse_one( + f"{customers.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb" + ), + } + context.upsert_model(SqlModel.parse_obj(customers_kwargs)) + # add auto restatement to orders - model = context.get_model("sushi.orders") - kind = { - **model.kind.dict(), + orders = context.get_model("sushi.orders") + orders_kind = { + **orders.kind.dict(), "auto_restatement_cron": "@hourly", } - kwargs = { - **model.dict(), - "kind": kind, + orders_kwargs = { + **orders.dict(), + "kind": orders_kind, } - context.upsert_model(PythonModel.parse_obj(kwargs)) - plan = context.plan_builder(skip_tests=True).build() - context.apply(plan) + context.upsert_model(PythonModel.parse_obj(orders_kwargs)) - # Mock run_merged_intervals to capture triggers arg - scheduler = context.scheduler() - run_merged_intervals_mock = mocker.patch.object( - scheduler, "run_merged_intervals", return_value=([], []) + spy = mocker.spy(sqlmesh.core.scheduler.Scheduler, "run_merged_intervals") + + context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full()) + + # PLAN: directly modified triggers + actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"] + actual_triggers_name = { + k.name: sorted([s.name for s in v.directly_modified_triggers]) + for k, v in actual_triggers.items() + if v.directly_modified_triggers + } + marketing_name = '"memory"."sushi"."marketing"' + customers_name = '"memory"."sushi"."customers"' + marketing_customers_names = sorted([marketing_name, customers_name]) + children_names = [ + f'"memory"."sushi"."{model}"' + for model in { + "waiter_as_customer_by_day", + "active_customers", + "count_customers_active", + "count_customers_inactive", + } + ] + assert actual_triggers_name == { + marketing_name: [marketing_name], + customers_name: [customers_name], + **{k: marketing_customers_names for k in children_names}, + } + + # PLAN: restatement triggers + spy.reset_mock() + context.plan( + restate_models=[ + '"memory"."sushi"."marketing"', + '"memory"."sushi"."order_items"', + '"memory"."sushi"."waiter_revenue_by_day"', + ], + auto_apply=True, + no_prompts=True, ) + order_items_name = '"memory"."sushi"."order_items"' + waiter_revenue_by_day_name = '"memory"."sushi"."waiter_revenue_by_day"' + actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"] + actual_triggers_name = { + k.name: sorted([s.name for s in v.restatement_triggers]) + for k, v in actual_triggers.items() + if v.restatement_triggers + } + assert actual_triggers_name == { + waiter_revenue_by_day_name: [waiter_revenue_by_day_name, order_items_name], + order_items_name: [order_items_name], + '"memory"."sushi"."top_waiters"': [waiter_revenue_by_day_name], + '"memory"."sushi"."customer_revenue_by_day"': [order_items_name], + '"memory"."sushi"."customer_revenue_lifetime"': [order_items_name], + } + + # RUN: select and auto-restatement triggers # User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream selected_models = {"top_waiters", "waiter_revenue_by_day"} selected_models_auto_upstream = {"order_items", "orders", "items"} @@ -1897,6 +1969,11 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt f'"memory"."sushi"."{model}"' for model in selected_models } + scheduler = context.scheduler() + run_merged_intervals_mock = mocker.patch.object( + scheduler, "run_merged_intervals", return_value=([], []) + ) + with time_machine.travel("2023-01-09 00:00:01 UTC"): scheduler.run( environment=c.PROD, From d105471be62be0d84448c4a49f2eccd565f98ea1 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Wed, 6 Aug 2025 12:37:21 -0500 Subject: [PATCH 07/13] Fix tests --- tests/core/test_integration.py | 29 +++++++++++++++++++++-------- tests/core/test_scheduler.py | 8 ++++++-- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index d1b49de216..e3ba5a1a85 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -1950,13 +1950,16 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt for k, v in actual_triggers.items() if v.restatement_triggers } - assert actual_triggers_name == { - waiter_revenue_by_day_name: [waiter_revenue_by_day_name, order_items_name], - order_items_name: [order_items_name], - '"memory"."sushi"."top_waiters"': [waiter_revenue_by_day_name], - '"memory"."sushi"."customer_revenue_by_day"': [order_items_name], - '"memory"."sushi"."customer_revenue_lifetime"': [order_items_name], - } + + assert sorted(actual_triggers_name[waiter_revenue_by_day_name]) == sorted( + [waiter_revenue_by_day_name, order_items_name] + ) + assert actual_triggers_name[order_items_name] == [order_items_name] + assert actual_triggers_name['"memory"."sushi"."top_waiters"'] == [waiter_revenue_by_day_name] + assert actual_triggers_name['"memory"."sushi"."customer_revenue_by_day"'] == [order_items_name] + assert actual_triggers_name['"memory"."sushi"."customer_revenue_lifetime"'] == [ + order_items_name + ] # RUN: select and auto-restatement triggers # User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream @@ -7110,7 +7113,17 @@ def plan_with_output(ctx: Context, environment: str): assert "New environment `dev` will be created from `prod`" in output.stdout assert "Differences from the `prod` environment" in output.stdout - assert "Directly Modified: test__dev.a" in output.stdout + assert ( + """MODEL ( + name test.a, ++ owner test, + kind FULL + ) + SELECT +- 5 AS col ++ 10 AS col""" + in output.stdout + ) # Case 6: Ensure that target environment and create_from environment are not the same output = plan_with_output(ctx, "prod") diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index e51c394407..571242082f 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -59,7 +59,10 @@ def test_interval_params(scheduler: Scheduler, sushi_context_fixed_date: Context start_ds = "2022-01-01" end_ds = "2022-02-05" - assert compute_interval_params([orders, waiter_revenue], start=start_ds, end=end_ds) == { + interval_params, _ = compute_interval_params( + [orders, waiter_revenue], start=start_ds, end=end_ds + ) + assert interval_params == { orders: [ (to_timestamp(start_ds), to_timestamp("2022-02-06")), ], @@ -91,7 +94,8 @@ def test_interval_params_nonconsecutive(scheduler: Scheduler, orders: Snapshot): orders.add_interval("2022-01-10", "2022-01-15") - assert compute_interval_params([orders], start=start_ds, end=end_ds) == { + interval_params, _ = compute_interval_params([orders], start=start_ds, end=end_ds) + assert interval_params == { orders: [ (to_timestamp(start_ds), to_timestamp("2022-01-10")), (to_timestamp("2022-01-16"), to_timestamp("2022-02-06")), From dd6a35f978811eeab0fdebd932199537476a55ba Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Mon, 11 Aug 2025 11:37:34 -0500 Subject: [PATCH 08/13] Revert plan-related triggers --- sqlmesh/core/console.py | 4 -- sqlmesh/core/plan/builder.py | 19 ++----- sqlmesh/core/plan/definition.py | 3 -- sqlmesh/core/plan/evaluator.py | 22 -------- sqlmesh/core/snapshot/definition.py | 2 - tests/core/test_integration.py | 81 ----------------------------- 6 files changed, 5 insertions(+), 126 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index b4d962f4f8..b043f2ab6b 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -3830,10 +3830,6 @@ def update_snapshot_evaluation_progress( message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}" if snapshot_evaluation_triggers.select_snapshot_triggers: message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}" - if snapshot_evaluation_triggers.directly_modified_triggers: - message += f" | directly_modified_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.directly_modified_triggers)}" - if snapshot_evaluation_triggers.restatement_triggers: - message += f" | restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.restatement_triggers)}" if audit_only: message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" diff --git a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index 43f77e2323..9451f9eb53 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -293,7 +293,7 @@ def build(self) -> Plan: else DeployabilityIndex.all_deployable() ) - restatements, restatement_triggers = self._build_restatements( + restatements = self._build_restatements( dag, earliest_interval_start(self._context_diff.snapshots.values(), self.execution_time), ) @@ -330,7 +330,6 @@ def build(self) -> Plan: indirectly_modified=indirectly_modified, deployability_index=deployability_index, restatements=restatements, - restatement_triggers=restatement_triggers, start_override_per_model=self._start_override_per_model, end_override_per_model=end_override_per_model, selected_models_to_backfill=self._backfill_models, @@ -353,14 +352,14 @@ def _build_dag(self) -> DAG[SnapshotId]: def _build_restatements( self, dag: DAG[SnapshotId], earliest_interval_start: TimeLike - ) -> t.Tuple[t.Dict[SnapshotId, Interval], t.Dict[SnapshotId, t.List[SnapshotId]]]: + ) -> t.Dict[SnapshotId, Interval]: restate_models = self._restate_models if restate_models == set(): # This is a warning but we print this as error since the Console is lacking API for warnings. self._console.log_error( "Provided restated models do not match any models. No models will be included in plan." ) - return {}, {} + return {} restatements: t.Dict[SnapshotId, Interval] = {} forward_only_preview_needed = self._forward_only_preview_needed @@ -384,7 +383,7 @@ def _build_restatements( is_preview = True if not restate_models: - return {}, {} + return {} start = self._start or earliest_interval_start end = self._end or now() @@ -394,7 +393,6 @@ def _build_restatements( if model_fqn not in self._model_fqn_to_snapshot: raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.") - restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} # Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands it's # restatement range that it's downstream dependencies all expand their restatement ranges as well. for s_id in dag: @@ -430,13 +428,6 @@ def _build_restatements( logger.info("Skipping restatement for model '%s'", snapshot.name) continue - if snapshot.name in restate_models: - restatement_triggers[s_id] = [s_id] - if restating_parents: - restatement_triggers[s_id] = restatement_triggers.get(s_id, []) + [ - s.snapshot_id for s in restating_parents - ] - possible_intervals = { restatements[p.snapshot_id] for p in restating_parents if p.is_incremental } @@ -465,7 +456,7 @@ def _build_restatements( restatements[s_id] = (snapshot_start, snapshot_end) - return restatements, restatement_triggers + return restatements def _build_directly_and_indirectly_modified( self, dag: DAG[SnapshotId] diff --git a/sqlmesh/core/plan/definition.py b/sqlmesh/core/plan/definition.py index 62b8544939..300ac62faf 100644 --- a/sqlmesh/core/plan/definition.py +++ b/sqlmesh/core/plan/definition.py @@ -58,7 +58,6 @@ class Plan(PydanticModel, frozen=True): deployability_index: DeployabilityIndex restatements: t.Dict[SnapshotId, Interval] - restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} start_override_per_model: t.Optional[t.Dict[str, datetime]] end_override_per_model: t.Optional[t.Dict[str, datetime]] @@ -257,7 +256,6 @@ def to_evaluatable(self) -> EvaluatablePlan: skip_backfill=self.skip_backfill, empty_backfill=self.empty_backfill, restatements={s.name: i for s, i in self.restatements.items()}, - restatement_triggers=self.restatement_triggers, is_dev=self.is_dev, allow_destructive_models=self.allow_destructive_models, forward_only=self.forward_only, @@ -300,7 +298,6 @@ class EvaluatablePlan(PydanticModel): skip_backfill: bool empty_backfill: bool restatements: t.Dict[str, Interval] - restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} is_dev: bool allow_destructive_models: t.Set[str] forward_only: bool diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index ba237f9be6..46142b7eeb 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -37,7 +37,6 @@ SnapshotCreationFailedError, SnapshotNameVersion, ) -from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers from sqlmesh.utils import to_snake_case from sqlmesh.core.state_sync import StateSync from sqlmesh.utils import CorrelationId @@ -245,27 +244,6 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla self.console.log_success("SKIP: No model batches to execute") return - directly_modified_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} - for parent, children in plan.indirectly_modified_snapshots.items(): - parent_id = stage.all_snapshots[parent].snapshot_id - directly_modified_triggers[parent_id] = directly_modified_triggers.get( - parent_id, [] - ) + [parent_id] - for child in children: - directly_modified_triggers[child] = directly_modified_triggers.get(child, []) + [ - parent_id - ] - directly_modified_triggers = { - k: list(dict.fromkeys(v)) for k, v in directly_modified_triggers.items() - } - snapshot_evaluation_triggers = { - s_id: SnapshotEvaluationTriggers( - directly_modified_triggers=directly_modified_triggers.get(s_id, []), - restatement_triggers=plan.restatement_triggers.get(s_id, []), - ) - for s_id in [s.snapshot_id for s in stage.all_snapshots.values()] - } - scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator) errors, _ = scheduler.run_merged_intervals( merged_intervals=stage.snapshot_to_intervals, diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 4b28ec5d6e..37b9acc275 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -332,8 +332,6 @@ class SnapshotEvaluationTriggers(PydanticModel): cron_ready: t.Optional[bool] = None auto_restatement_triggers: t.List[SnapshotId] = [] select_snapshot_triggers: t.List[SnapshotId] = [] - directly_modified_triggers: t.List[SnapshotId] = [] - restatement_triggers: t.List[SnapshotId] = [] class SnapshotInfoMixin(ModelKindMixin): diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index e3ba5a1a85..db2a5f7cf2 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -26,7 +26,6 @@ from sqlmesh import CustomMaterialization -import sqlmesh from sqlmesh.cli.project_init import init_example_project from sqlmesh.core import constants as c from sqlmesh.core import dialect as d @@ -1868,27 +1867,6 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt context, plan = init_and_plan_context("examples/sushi") context.apply(plan) - # modify 3 models - # - 2 breaking changes for testing plan directly modified triggers - # - 1 adding an auto-restatement for subsequent `run` test - marketing = context.get_model("sushi.marketing") - marketing_kwargs = { - **marketing.dict(), - "query": d.parse_one( - f"{marketing.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb" - ), - } - context.upsert_model(SqlModel.parse_obj(marketing_kwargs)) - - customers = context.get_model("sushi.customers") - customers_kwargs = { - **customers.dict(), - "query": d.parse_one( - f"{customers.query.sql(dialect='duckdb')} ORDER BY customer_id", dialect="duckdb" - ), - } - context.upsert_model(SqlModel.parse_obj(customers_kwargs)) - # add auto restatement to orders orders = context.get_model("sushi.orders") orders_kind = { @@ -1901,67 +1879,8 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt } context.upsert_model(PythonModel.parse_obj(orders_kwargs)) - spy = mocker.spy(sqlmesh.core.scheduler.Scheduler, "run_merged_intervals") - context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full()) - # PLAN: directly modified triggers - actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"] - actual_triggers_name = { - k.name: sorted([s.name for s in v.directly_modified_triggers]) - for k, v in actual_triggers.items() - if v.directly_modified_triggers - } - marketing_name = '"memory"."sushi"."marketing"' - customers_name = '"memory"."sushi"."customers"' - marketing_customers_names = sorted([marketing_name, customers_name]) - children_names = [ - f'"memory"."sushi"."{model}"' - for model in { - "waiter_as_customer_by_day", - "active_customers", - "count_customers_active", - "count_customers_inactive", - } - ] - assert actual_triggers_name == { - marketing_name: [marketing_name], - customers_name: [customers_name], - **{k: marketing_customers_names for k in children_names}, - } - - # PLAN: restatement triggers - spy.reset_mock() - context.plan( - restate_models=[ - '"memory"."sushi"."marketing"', - '"memory"."sushi"."order_items"', - '"memory"."sushi"."waiter_revenue_by_day"', - ], - auto_apply=True, - no_prompts=True, - ) - - order_items_name = '"memory"."sushi"."order_items"' - waiter_revenue_by_day_name = '"memory"."sushi"."waiter_revenue_by_day"' - actual_triggers = spy.call_args.kwargs["snapshot_evaluation_triggers"] - actual_triggers_name = { - k.name: sorted([s.name for s in v.restatement_triggers]) - for k, v in actual_triggers.items() - if v.restatement_triggers - } - - assert sorted(actual_triggers_name[waiter_revenue_by_day_name]) == sorted( - [waiter_revenue_by_day_name, order_items_name] - ) - assert actual_triggers_name[order_items_name] == [order_items_name] - assert actual_triggers_name['"memory"."sushi"."top_waiters"'] == [waiter_revenue_by_day_name] - assert actual_triggers_name['"memory"."sushi"."customer_revenue_by_day"'] == [order_items_name] - assert actual_triggers_name['"memory"."sushi"."customer_revenue_lifetime"'] == [ - order_items_name - ] - - # RUN: select and auto-restatement triggers # User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream selected_models = {"top_waiters", "waiter_revenue_by_day"} selected_models_auto_upstream = {"order_items", "orders", "items"} From 2f276091a07868a0d3e65c3a418a9d83a33888bf Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Mon, 11 Aug 2025 16:23:08 -0500 Subject: [PATCH 09/13] Infer cron ready from auto-restatement intervals --- sqlmesh/core/plan/stages.py | 2 +- sqlmesh/core/scheduler.py | 29 ++++++++++++++--------------- tests/core/test_scheduler.py | 10 ++++------ web/server/api/endpoints/plan.py | 2 +- 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/sqlmesh/core/plan/stages.py b/sqlmesh/core/plan/stages.py index c1eb3f9927..440719e89c 100644 --- a/sqlmesh/core/plan/stages.py +++ b/sqlmesh/core/plan/stages.py @@ -553,7 +553,7 @@ def _missing_intervals( snapshots_by_name: t.Dict[str, Snapshot], deployability_index: DeployabilityIndex, ) -> SnapshotToIntervals: - missing_intervals, _ = merged_missing_intervals( + missing_intervals = merged_missing_intervals( snapshots=snapshots_by_name.values(), start=plan.start, end=plan.end, diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index aca4fec167..e2360cbd11 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -28,10 +28,11 @@ snapshots_to_dag, Intervals, ) -from sqlmesh.core.snapshot.definition import check_ready_intervals from sqlmesh.core.snapshot.definition import ( Interval, SnapshotEvaluationTriggers, + SnapshotIntervals, + check_ready_intervals, expand_range, parent_snapshots_by_name, ) @@ -147,7 +148,7 @@ def merged_missing_intervals( ignore_cron: bool = False, end_bounded: bool = False, selected_snapshots: t.Optional[t.Set[str]] = None, - ) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]: + ) -> SnapshotToIntervals: """Find the largest contiguous date interval parameters based only on what is missing. For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found, @@ -169,9 +170,9 @@ def merged_missing_intervals( selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run. Returns: - A tuple containing a dict containing all snapshots needing to be run with their associated interval params and a list of snapshots that are ready to run based on their naive cron schedule (ignoring plan/run context and other attributes). + A dict containing all snapshots needing to be run with their associated interval params. """ - snapshots_to_intervals, snapshots_naive_cron_ready = merged_missing_intervals( + snapshots_to_intervals = merged_missing_intervals( snapshots=self.snapshot_per_version.values(), start=start, end=end, @@ -189,7 +190,7 @@ def merged_missing_intervals( snapshots_to_intervals = { s: i for s, i in snapshots_to_intervals.items() if s.name in selected_snapshots } - return snapshots_to_intervals, snapshots_naive_cron_ready + return snapshots_to_intervals def evaluate( self, @@ -748,6 +749,7 @@ def _run_or_audit( for s_id, interval in (remove_intervals or {}).items(): self.snapshots[s_id].remove_interval(interval) + auto_restated_intervals: t.List[SnapshotIntervals] = [] auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} if auto_restatement_enabled: auto_restated_intervals, auto_restatement_triggers = apply_auto_restatements( @@ -757,8 +759,9 @@ def _run_or_audit( self.state_sync.update_auto_restatements( {s.name_version: s.next_auto_restatement_ts for s in self.snapshots.values()} ) + auto_restated_snapshots = {snapshot.snapshot_id for snapshot in auto_restated_intervals} - merged_intervals, snapshots_naive_cron_ready = self.merged_missing_intervals( + merged_intervals = self.merged_missing_intervals( start, end, execution_time, @@ -801,7 +804,7 @@ def _run_or_audit( all_snapshot_triggers: t.Dict[SnapshotId, SnapshotEvaluationTriggers] = { s_id: SnapshotEvaluationTriggers( ignore_cron_flag=ignore_cron, - cron_ready=s_id in snapshots_naive_cron_ready, + cron_ready=s_id not in auto_restated_snapshots, auto_restatement_triggers=auto_restatement_triggers.get(s_id, []), select_snapshot_triggers=select_snapshot_triggers.get(s_id, []), ) @@ -969,7 +972,7 @@ def merged_missing_intervals( end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, ignore_cron: bool = False, end_bounded: bool = False, -) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]: +) -> SnapshotToIntervals: """Find the largest contiguous date interval parameters based only on what is missing. For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found, @@ -1019,7 +1022,7 @@ def compute_interval_params( end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, ignore_cron: bool = False, end_bounded: bool = False, -) -> t.Tuple[SnapshotToIntervals, t.List[SnapshotId]]: +) -> SnapshotToIntervals: """Find the largest contiguous date interval parameters based only on what is missing. For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found, @@ -1041,7 +1044,7 @@ def compute_interval_params( allow_partials, and other attributes that could cause the intervals to exceed the target end date. Returns: - A tuple containing a dict containing all snapshots needing to be run with their associated interval params and a list of snapshots that are ready to run based on their naive cron schedule (ignoring plan/run context and other attributes). + A dict containing all snapshots needing to be run with their associated interval params. """ snapshot_merged_intervals = {} @@ -1069,11 +1072,7 @@ def compute_interval_params( contiguous_batch.append((next_batch[0][0], next_batch[-1][-1])) snapshot_merged_intervals[snapshot] = contiguous_batch - snapshots_naive_cron_ready = [ - snap.snapshot_id for snap in missing_intervals(snapshots, execution_time=execution_time) - ] - - return snapshot_merged_intervals, snapshots_naive_cron_ready + return snapshot_merged_intervals def interval_diff( diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 571242082f..d58003981b 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -59,9 +59,7 @@ def test_interval_params(scheduler: Scheduler, sushi_context_fixed_date: Context start_ds = "2022-01-01" end_ds = "2022-02-05" - interval_params, _ = compute_interval_params( - [orders, waiter_revenue], start=start_ds, end=end_ds - ) + interval_params = compute_interval_params([orders, waiter_revenue], start=start_ds, end=end_ds) assert interval_params == { orders: [ (to_timestamp(start_ds), to_timestamp("2022-02-06")), @@ -82,7 +80,7 @@ def _get_batched_missing_intervals( end: TimeLike, execution_time: t.Optional[TimeLike] = None, ) -> SnapshotToIntervals: - merged_intervals, _ = scheduler.merged_missing_intervals(start, end, execution_time) + merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time) return scheduler.batch_intervals(merged_intervals, mocker.Mock(), mocker.Mock()) return _get_batched_missing_intervals @@ -94,7 +92,7 @@ def test_interval_params_nonconsecutive(scheduler: Scheduler, orders: Snapshot): orders.add_interval("2022-01-10", "2022-01-15") - interval_params, _ = compute_interval_params([orders], start=start_ds, end=end_ds) + interval_params = compute_interval_params([orders], start=start_ds, end=end_ds) assert interval_params == { orders: [ (to_timestamp(start_ds), to_timestamp("2022-01-10")), @@ -111,7 +109,7 @@ def test_interval_params_missing(scheduler: Scheduler, sushi_context_fixed_date: start_ds = "2022-01-01" end_ds = "2022-03-01" - interval_params, _ = compute_interval_params( + interval_params = compute_interval_params( sushi_context_fixed_date.snapshots.values(), start=start_ds, end=end_ds ) assert interval_params[waiters] == [ diff --git a/web/server/api/endpoints/plan.py b/web/server/api/endpoints/plan.py index 39b0e810c6..ba4feb34d3 100644 --- a/web/server/api/endpoints/plan.py +++ b/web/server/api/endpoints/plan.py @@ -132,7 +132,7 @@ def _get_plan_changes(context: Context, plan: Plan) -> models.PlanChanges: def _get_plan_backfills(context: Context, plan: Plan) -> t.Dict[str, t.Any]: """Get plan backfills""" - merged_intervals, _ = context.scheduler().merged_missing_intervals() + merged_intervals = context.scheduler().merged_missing_intervals() batches = context.scheduler().batch_intervals(merged_intervals, None, EnvironmentNamingInfo()) tasks = {snapshot.name: len(intervals) for snapshot, intervals in batches.items()} snapshots = plan.context_diff.snapshots From 50ca148db55c361be597576b333d084faf070d48 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 12 Aug 2025 13:17:33 -0500 Subject: [PATCH 10/13] Report auto-restatement triggers only --- sqlmesh/core/console.py | 30 ++++------- sqlmesh/core/context.py | 19 ++++--- sqlmesh/core/plan/stages.py | 3 +- sqlmesh/core/scheduler.py | 48 ++---------------- sqlmesh/core/snapshot/definition.py | 22 +++----- tests/core/test_integration.py | 79 ++++++++++++++++------------- tests/core/test_scheduler.py | 11 ++-- tests/core/test_snapshot.py | 12 +++-- web/server/console.py | 5 +- 9 files changed, 89 insertions(+), 140 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index b043f2ab6b..861f7ca816 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -37,12 +37,7 @@ SnapshotId, SnapshotInfoLike, ) -from sqlmesh.core.snapshot.definition import ( - Interval, - Intervals, - SnapshotTableInfo, - SnapshotEvaluationTriggers, -) +from sqlmesh.core.snapshot.definition import Interval, Intervals, SnapshotTableInfo from sqlmesh.core.test import ModelTest from sqlmesh.utils import rich as srich from sqlmesh.utils import Verbosity @@ -433,7 +428,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Updates the snapshot evaluation progress.""" @@ -581,7 +576,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: pass @@ -1063,7 +1058,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Update the snapshot evaluation progress.""" if ( @@ -3647,7 +3642,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id] @@ -3817,22 +3812,15 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" - if snapshot_evaluation_triggers: - if snapshot_evaluation_triggers.ignore_cron_flag is not None: - message += f" | ignore_cron_flag={snapshot_evaluation_triggers.ignore_cron_flag}" - if snapshot_evaluation_triggers.cron_ready is not None: - message += f" | cron_ready={snapshot_evaluation_triggers.cron_ready}" - if snapshot_evaluation_triggers.auto_restatement_triggers: - message += f" | auto_restatement_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.auto_restatement_triggers)}" - if snapshot_evaluation_triggers.select_snapshot_triggers: - message += f" | select_snapshot_triggers={','.join(trigger.name for trigger in snapshot_evaluation_triggers.select_snapshot_triggers)}" + if auto_restatement_triggers: + message += f" | Auto-restatement triggers {', '.join(trigger.name for trigger in auto_restatement_triggers)}" if audit_only: - message = f"Audited {snapshot.name} duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" + message = f"Audited {snapshot.name} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" self._write(message) diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 8a12136f83..9022f3f069 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -2307,9 +2307,11 @@ def check_intervals( } if select_models: - selected, _ = self._select_models_for_run(select_models, True, snapshots.values()) + selected: t.Collection[str] = self._select_models_for_run( + select_models, True, snapshots.values() + ) else: - selected = set(snapshots.keys()) + selected = snapshots.keys() results = {} execution_context = self.execution_context(snapshots=snapshots) @@ -2459,9 +2461,8 @@ def _run( scheduler = self.scheduler(environment=environment) snapshots = scheduler.snapshots - select_models_auto_upstream = None if select_models is not None: - select_models, select_models_auto_upstream = self._select_models_for_run( + select_models = self._select_models_for_run( select_models, no_auto_upstream, snapshots.values() ) @@ -2473,7 +2474,6 @@ def _run( ignore_cron=ignore_cron, circuit_breaker=circuit_breaker, selected_snapshots=select_models, - selected_snapshots_auto_upstream=select_models_auto_upstream, auto_restatement_enabled=environment.lower() == c.PROD, run_environment_statements=True, ) @@ -2889,7 +2889,7 @@ def _select_models_for_run( select_models: t.Collection[str], no_auto_upstream: bool, snapshots: t.Collection[Snapshot], - ) -> t.Tuple[t.Set[str], t.Set[str]]: + ) -> t.Set[str]: models: UniqueKeyDict[str, Model] = UniqueKeyDict( "models", **{s.name: s.model for s in snapshots if s.is_model} ) @@ -2898,10 +2898,9 @@ def _select_models_for_run( dag.add(fqn, model.depends_on) model_selector = self._new_selector(models=models, dag=dag) result = set(model_selector.expand_model_selections(select_models)) - if no_auto_upstream: - return result, set() - result_with_upstream = set(dag.subdag(*result)) - return result_with_upstream, result_with_upstream - result + if not no_auto_upstream: + result = set(dag.subdag(*result)) + return result @cached_property def _project_type(self) -> str: diff --git a/sqlmesh/core/plan/stages.py b/sqlmesh/core/plan/stages.py index 440719e89c..82223dd807 100644 --- a/sqlmesh/core/plan/stages.py +++ b/sqlmesh/core/plan/stages.py @@ -553,7 +553,7 @@ def _missing_intervals( snapshots_by_name: t.Dict[str, Snapshot], deployability_index: DeployabilityIndex, ) -> SnapshotToIntervals: - missing_intervals = merged_missing_intervals( + return merged_missing_intervals( snapshots=snapshots_by_name.values(), start=plan.start, end=plan.end, @@ -568,7 +568,6 @@ def _missing_intervals( start_override_per_model=plan.start_override_per_model, end_override_per_model=plan.end_override_per_model, ) - return missing_intervals def _get_audit_only_snapshots( self, new_snapshots: t.Dict[SnapshotId, Snapshot] diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index e2360cbd11..8565412164 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -28,11 +28,10 @@ snapshots_to_dag, Intervals, ) +from sqlmesh.core.snapshot.definition import check_ready_intervals from sqlmesh.core.snapshot.definition import ( Interval, - SnapshotEvaluationTriggers, SnapshotIntervals, - check_ready_intervals, expand_range, parent_snapshots_by_name, ) @@ -168,9 +167,6 @@ def merged_missing_intervals( end_bounded: If set to true, the returned intervals will be bounded by the target end date, disregarding lookback, allow_partials, and other attributes that could cause the intervals to exceed the target end date. selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run. - - Returns: - A dict containing all snapshots needing to be run with their associated interval params. """ snapshots_to_intervals = merged_missing_intervals( snapshots=self.snapshot_per_version.values(), @@ -267,7 +263,6 @@ def run( ignore_cron: bool = False, end_bounded: bool = False, selected_snapshots: t.Optional[t.Set[str]] = None, - selected_snapshots_auto_upstream: t.Optional[t.Set[str]] = None, circuit_breaker: t.Optional[t.Callable[[], bool]] = None, deployability_index: t.Optional[DeployabilityIndex] = None, auto_restatement_enabled: bool = False, @@ -284,7 +279,6 @@ def run( ignore_cron=ignore_cron, end_bounded=end_bounded, selected_snapshots=selected_snapshots, - selected_snapshots_auto_upstream=selected_snapshots_auto_upstream, circuit_breaker=circuit_breaker, deployability_index=deployability_index, auto_restatement_enabled=auto_restatement_enabled, @@ -422,6 +416,7 @@ def run_merged_intervals( selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, run_environment_statements: bool = False, audit_only: bool = False, + restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None, auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}, ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: """Runs precomputed batches of missing intervals. @@ -539,9 +534,7 @@ def run_node(node: SchedulingUnit) -> None: evaluation_duration_ms, num_audits - num_audits_failed, num_audits_failed, - snapshot_evaluation_triggers=snapshot_evaluation_triggers.get( - snapshot.snapshot_id - ), + auto_restatement_triggers=auto_restatement_triggers.get(snapshot.snapshot_id), ) elif isinstance(node, CreateNode): self.snapshot_evaluator.create_snapshot( @@ -694,7 +687,6 @@ def _run_or_audit( ignore_cron: bool = False, end_bounded: bool = False, selected_snapshots: t.Optional[t.Set[str]] = None, - selected_snapshots_auto_upstream: t.Optional[t.Set[str]] = None, circuit_breaker: t.Optional[t.Callable[[], bool]] = None, deployability_index: t.Optional[DeployabilityIndex] = None, auto_restatement_enabled: bool = False, @@ -718,7 +710,6 @@ def _run_or_audit( end_bounded: If set to true, the evaluated intervals will be bounded by the target end date, disregarding lookback, allow_partials, and other attributes that could cause the intervals to exceed the target end date. selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run. - selected_snapshots_auto_upstream: The set of selected_snapshots that were automatically added because they're upstream of a selected snapshot. circuit_breaker: An optional handler which checks if the run should be aborted. deployability_index: Determines snapshots that are deployable in the context of this render. auto_restatement_enabled: Whether to enable auto restatements. @@ -777,38 +768,9 @@ def _run_or_audit( return CompletionStatus.NOTHING_TO_DO merged_intervals_snapshots = {snapshot.snapshot_id for snapshot in merged_intervals} - select_snapshot_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} - if selected_snapshots and selected_snapshots_auto_upstream: - # actually selected snapshots are their own triggers - selected_snapshots_no_auto_upstream = ( - selected_snapshots - selected_snapshots_auto_upstream - ) - select_snapshot_triggers = { - s_id: [s_id] - for s_id in [ - snapshot_id - for snapshot_id in merged_intervals_snapshots - if snapshot_id.name in selected_snapshots_no_auto_upstream - ] - } - # trace upstream by walking downstream on reversed dag - reversed_dag = snapshots_to_dag(self.snapshots.values()).reversed - for s_id in reversed_dag: - if s_id in merged_intervals_snapshots: - triggers = select_snapshot_triggers.get(s_id, []) - for parent_s_id in reversed_dag.graph.get(s_id, set()): - triggers.extend(select_snapshot_triggers.get(parent_s_id, [])) - select_snapshot_triggers[s_id] = list(dict.fromkeys(triggers)) - - all_snapshot_triggers: t.Dict[SnapshotId, SnapshotEvaluationTriggers] = { - s_id: SnapshotEvaluationTriggers( - ignore_cron_flag=ignore_cron, - cron_ready=s_id not in auto_restated_snapshots, - auto_restatement_triggers=auto_restatement_triggers.get(s_id, []), - select_snapshot_triggers=select_snapshot_triggers.get(s_id, []), - ) - for s_id in merged_intervals_snapshots + auto_restatement_triggers_dict: t.Dict[SnapshotId, t.List[SnapshotId]] = { + s_id: auto_restatement_triggers.get(s_id, []) for s_id in merged_intervals_snapshots } errors, _ = self.run_merged_intervals( diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 37b9acc275..45740d9810 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -21,7 +21,7 @@ from sqlmesh.core.model import Model, ModelKindMixin, ModelKindName, ViewKind, CustomKind from sqlmesh.core.model.definition import _Model from sqlmesh.core.node import IntervalUnit, NodeType -from sqlmesh.utils import sanitize_name +from sqlmesh.utils import sanitize_name, unique from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import ( TimeLike, @@ -327,13 +327,6 @@ def table_name_for_environment( return table -class SnapshotEvaluationTriggers(PydanticModel): - ignore_cron_flag: t.Optional[bool] = None - cron_ready: t.Optional[bool] = None - auto_restatement_triggers: t.List[SnapshotId] = [] - select_snapshot_triggers: t.List[SnapshotId] = [] - - class SnapshotInfoMixin(ModelKindMixin): name: str dev_version_: t.Optional[str] @@ -2229,14 +2222,15 @@ def apply_auto_restatements( # auto-restated snapshot is its own trigger upstream_triggers = [s_id] + else: + # inherit each parent's auto-restatement triggers (if any) + for parent_s_id in snapshot.parents: + if parent_s_id in auto_restatement_triggers: + upstream_triggers.extend(auto_restatement_triggers[parent_s_id]) - for parent_s_id in snapshot.parents: - if parent_s_id in auto_restatement_triggers: - upstream_triggers.extend(auto_restatement_triggers[parent_s_id]) - - # remove duplicate triggers + # remove duplicate triggers, retaining order and keeping first seen of duplicates if upstream_triggers: - auto_restatement_triggers[s_id] = list(dict.fromkeys(upstream_triggers)) + auto_restatement_triggers[s_id] = unique(upstream_triggers) if auto_restated_intervals: auto_restated_interval_start = sys.maxsize diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py index db2a5f7cf2..827d84e8b9 100644 --- a/tests/core/test_integration.py +++ b/tests/core/test_integration.py @@ -1867,7 +1867,7 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt context, plan = init_and_plan_context("examples/sushi") context.apply(plan) - # add auto restatement to orders + # auto-restatement triggers orders = context.get_model("sushi.orders") orders_kind = { **orders.kind.dict(), @@ -1879,57 +1879,63 @@ def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixt } context.upsert_model(PythonModel.parse_obj(orders_kwargs)) - context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full()) + order_items = context.get_model("sushi.order_items") + order_items_kind = { + **order_items.kind.dict(), + "auto_restatement_cron": "@hourly", + } + order_items_kwargs = { + **order_items.dict(), + "kind": order_items_kind, + } + context.upsert_model(PythonModel.parse_obj(order_items_kwargs)) - # User selects top_waiters and waiter_revenue_by_day, others added as auto-upstream - selected_models = {"top_waiters", "waiter_revenue_by_day"} - selected_models_auto_upstream = {"order_items", "orders", "items"} - selected_snapshots = { - f'"memory"."sushi"."{model}"' for model in selected_models | selected_models_auto_upstream + waiter_revenue_by_day = context.get_model("sushi.waiter_revenue_by_day") + waiter_revenue_by_day_kind = { + **waiter_revenue_by_day.kind.dict(), + "auto_restatement_cron": "@hourly", } - selected_snapshots_auto_upstream = selected_snapshots - { - f'"memory"."sushi"."{model}"' for model in selected_models + waiter_revenue_by_day_kwargs = { + **waiter_revenue_by_day.dict(), + "kind": waiter_revenue_by_day_kind, } + context.upsert_model(SqlModel.parse_obj(waiter_revenue_by_day_kwargs)) + + context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full()) scheduler = context.scheduler() - run_merged_intervals_mock = mocker.patch.object( - scheduler, "run_merged_intervals", return_value=([], []) - ) + + import sqlmesh + + spy = mocker.spy(sqlmesh.core.scheduler.Scheduler, "run_merged_intervals") with time_machine.travel("2023-01-09 00:00:01 UTC"): scheduler.run( environment=c.PROD, - selected_snapshots=selected_snapshots, - selected_snapshots_auto_upstream=selected_snapshots_auto_upstream, start="2023-01-01", auto_restatement_enabled=True, ) - assert run_merged_intervals_mock.called + assert spy.called - actual_triggers = run_merged_intervals_mock.call_args.kwargs["snapshot_evaluation_triggers"] - - # validate ignore_cron not passed and all model crons ready - assert all( - not trigger.ignore_cron_flag and trigger.cron_ready for trigger in actual_triggers.values() - ) + actual_triggers = spy.call_args.kwargs["auto_restatement_triggers"] + actual_triggers = {k: v for k, v in actual_triggers.items() if v} + assert len(actual_triggers) == 12 for id, trigger in actual_triggers.items(): - # top_waiters is its own trigger, waiter_revenue_by_day is upstream of it, everyone else is upstream of both - select_triggers = [t.name for t in trigger.select_snapshot_triggers] - assert ( - select_triggers == ['"memory"."sushi"."top_waiters"'] - if id.name == '"memory"."sushi"."top_waiters"' - else ['"memory"."sushi"."waiter_revenue_by_day"', '"memory"."sushi"."top_waiters"'] - ) + model_name = id.name.replace('"memory"."sushi".', "").replace('"', "") + auto_restatement_triggers = [ + t.name.replace('"memory"."sushi".', "").replace('"', "") for t in trigger + ] - # everyone other than items is downstream of orders - auto_restatement_triggers = [t.name for t in trigger.auto_restatement_triggers] - assert ( - auto_restatement_triggers == [] - if id.name == '"memory"."sushi"."items"' - else ['"memory"."sushi"."orders"'] - ) + if model_name in ("orders", "order_items", "waiter_revenue_by_day"): + assert auto_restatement_triggers == [model_name] + elif model_name in ("customer_revenue_lifetime", "customer_revenue_by_day"): + assert sorted(auto_restatement_triggers) == sorted(["orders", "order_items"]) + elif model_name == "top_waiters": + assert auto_restatement_triggers == ["waiter_revenue_by_day"] + else: + assert auto_restatement_triggers == ["orders"] @time_machine.travel("2023-01-08 15:00:00 UTC") @@ -7032,6 +7038,7 @@ def plan_with_output(ctx: Context, environment: str): assert "New environment `dev` will be created from `prod`" in output.stdout assert "Differences from the `prod` environment" in output.stdout + stdout_rstrip = "\n".join([line.rstrip() for line in output.stdout.split("\n")]) assert ( """MODEL ( name test.a, @@ -7041,7 +7048,7 @@ def plan_with_output(ctx: Context, environment: str): SELECT - 5 AS col + 10 AS col""" - in output.stdout + in stdout_rstrip ) # Case 6: Ensure that target environment and create_from environment are not the same diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index d58003981b..b74aa3480e 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -59,8 +59,7 @@ def test_interval_params(scheduler: Scheduler, sushi_context_fixed_date: Context start_ds = "2022-01-01" end_ds = "2022-02-05" - interval_params = compute_interval_params([orders, waiter_revenue], start=start_ds, end=end_ds) - assert interval_params == { + assert compute_interval_params([orders, waiter_revenue], start=start_ds, end=end_ds) == { orders: [ (to_timestamp(start_ds), to_timestamp("2022-02-06")), ], @@ -92,8 +91,7 @@ def test_interval_params_nonconsecutive(scheduler: Scheduler, orders: Snapshot): orders.add_interval("2022-01-10", "2022-01-15") - interval_params = compute_interval_params([orders], start=start_ds, end=end_ds) - assert interval_params == { + assert compute_interval_params([orders], start=start_ds, end=end_ds) == { orders: [ (to_timestamp(start_ds), to_timestamp("2022-01-10")), (to_timestamp("2022-01-16"), to_timestamp("2022-02-06")), @@ -109,10 +107,9 @@ def test_interval_params_missing(scheduler: Scheduler, sushi_context_fixed_date: start_ds = "2022-01-01" end_ds = "2022-03-01" - interval_params = compute_interval_params( + assert compute_interval_params( sushi_context_fixed_date.snapshots.values(), start=start_ds, end=end_ds - ) - assert interval_params[waiters] == [ + )[waiters] == [ (to_timestamp(start_ds), to_timestamp("2022-03-02")), ] diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 8a8a349892..db61b9cabf 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -3280,6 +3280,12 @@ def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot): def test_auto_restatement_triggers(make_snapshot): + # Auto restatements: + # a, c, d + # dag: + # a -> b + # a -> c + # [b, c, d] -> e model_a = SqlModel( name="test_model_a", kind=IncrementalByTimeRangeKind( @@ -3372,12 +3378,10 @@ def test_auto_restatement_triggers(make_snapshot): ) assert auto_restatement_triggers[snapshot_a.snapshot_id] == [snapshot_a.snapshot_id] + assert auto_restatement_triggers[snapshot_c.snapshot_id] == [snapshot_c.snapshot_id] assert auto_restatement_triggers[snapshot_d.snapshot_id] == [snapshot_d.snapshot_id] assert auto_restatement_triggers[snapshot_b.snapshot_id] == [snapshot_a.snapshot_id] - assert auto_restatement_triggers[snapshot_c.snapshot_id] == [ - snapshot_c.snapshot_id, - snapshot_a.snapshot_id, - ] + # a via b, c and d directly assert sorted(auto_restatement_triggers[snapshot_e.snapshot_id]) == [ snapshot_a.snapshot_id, snapshot_c.snapshot_id, diff --git a/web/server/console.py b/web/server/console.py index 5af93864f6..902a85418c 100644 --- a/web/server/console.py +++ b/web/server/console.py @@ -9,8 +9,7 @@ from sqlmesh.core.console import TerminalConsole from sqlmesh.core.environment import EnvironmentNamingInfo from sqlmesh.core.plan.definition import EvaluatablePlan -from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo -from sqlmesh.core.snapshot.definition import SnapshotEvaluationTriggers +from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo, SnapshotId from sqlmesh.core.test import ModelTest from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.utils.date import now_timestamp @@ -143,7 +142,7 @@ def update_snapshot_evaluation_progress( num_audits_passed: int, num_audits_failed: int, audit_only: bool = False, - snapshot_evaluation_triggers: t.Optional[SnapshotEvaluationTriggers] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: if audit_only: return From 44ebb55c5130fbf0924b3f82e3b339ae16f13897 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Tue, 12 Aug 2025 15:03:32 -0500 Subject: [PATCH 11/13] Improve message formatting --- sqlmesh/core/console.py | 2 +- sqlmesh/core/scheduler.py | 19 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 861f7ca816..3b9fce7f4e 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -3817,7 +3817,7 @@ def update_snapshot_evaluation_progress( message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" if auto_restatement_triggers: - message += f" | Auto-restatement triggers {', '.join(trigger.name for trigger in auto_restatement_triggers)}" + message += f" | auto_restatement_triggers=[{', '.join(trigger.name for trigger in auto_restatement_triggers)}]" if audit_only: message = f"Audited {snapshot.name} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 8565412164..63b2e67c06 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -31,7 +31,6 @@ from sqlmesh.core.snapshot.definition import check_ready_intervals from sqlmesh.core.snapshot.definition import ( Interval, - SnapshotIntervals, expand_range, parent_snapshots_by_name, ) @@ -740,17 +739,15 @@ def _run_or_audit( for s_id, interval in (remove_intervals or {}).items(): self.snapshots[s_id].remove_interval(interval) - auto_restated_intervals: t.List[SnapshotIntervals] = [] - auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} + all_auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} if auto_restatement_enabled: - auto_restated_intervals, auto_restatement_triggers = apply_auto_restatements( + auto_restated_intervals, all_auto_restatement_triggers = apply_auto_restatements( self.snapshots, execution_time ) self.state_sync.add_snapshots_intervals(auto_restated_intervals) self.state_sync.update_auto_restatements( {s.name_version: s.next_auto_restatement_ts for s in self.snapshots.values()} ) - auto_restated_snapshots = {snapshot.snapshot_id for snapshot in auto_restated_intervals} merged_intervals = self.merged_missing_intervals( start, @@ -767,11 +764,13 @@ def _run_or_audit( if not merged_intervals: return CompletionStatus.NOTHING_TO_DO - merged_intervals_snapshots = {snapshot.snapshot_id for snapshot in merged_intervals} - - auto_restatement_triggers_dict: t.Dict[SnapshotId, t.List[SnapshotId]] = { - s_id: auto_restatement_triggers.get(s_id, []) for s_id in merged_intervals_snapshots - } + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} + if all_auto_restatement_triggers: + merged_intervals_snapshots = {snapshot.snapshot_id for snapshot in merged_intervals} + auto_restatement_triggers = { + s_id: all_auto_restatement_triggers.get(s_id, []) + for s_id in merged_intervals_snapshots + } errors, _ = self.run_merged_intervals( merged_intervals=merged_intervals, From a8fb199573f19827b68952ac7e56e1e9717ec9c6 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Wed, 13 Aug 2025 08:17:45 -0500 Subject: [PATCH 12/13] Fix rebase --- sqlmesh/core/scheduler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 63b2e67c06..caf7cc534c 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -415,7 +415,6 @@ def run_merged_intervals( selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, run_environment_statements: bool = False, audit_only: bool = False, - restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None, auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}, ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: """Runs precomputed batches of missing intervals. @@ -782,7 +781,6 @@ def _run_or_audit( end=end, run_environment_statements=run_environment_statements, audit_only=audit_only, - restatements=remove_intervals, auto_restatement_triggers=auto_restatement_triggers, ) From 81388688f1cae8d0830ffbbd309b46f5a6648969 Mon Sep 17 00:00:00 2001 From: Trey Spiller Date: Wed, 20 Aug 2025 16:00:56 -0500 Subject: [PATCH 13/13] Fix style --- sqlmesh/core/scheduler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index caf7cc534c..8096ffece1 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -532,7 +532,9 @@ def run_node(node: SchedulingUnit) -> None: evaluation_duration_ms, num_audits - num_audits_failed, num_audits_failed, - auto_restatement_triggers=auto_restatement_triggers.get(snapshot.snapshot_id), + auto_restatement_triggers=auto_restatement_triggers.get( + snapshot.snapshot_id + ), ) elif isinstance(node, CreateNode): self.snapshot_evaluator.create_snapshot(