diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 7177efe927..f8a171730a 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -60,6 +60,36 @@ SchedulingUnit = t.Tuple[str, t.Tuple[Interval, int]] +class SignalListener: + def on_signal_register( + self, + snapshots: t.List[Snapshot], + ) -> None: + pass + + def on_signal_start( + self, + snapshot: Snapshot, + signal_name: str, + signal_index: int, + intervals: Intervals, + signal_kwargs: t.Dict[str, t.Optional[exp.Expression]], + ) -> None: + pass + + def on_signal_end( + self, + snapshot: Snapshot, + signal_name: str, + signal_index: int, + signal_kwargs: t.Dict[str, t.Optional[exp.Expression]], + intervals: Intervals, + ready_intervals: Intervals, + error: t.Optional[Exception] = None, + ) -> None: + pass + + class Scheduler: """Schedules and manages the evaluation of snapshots. @@ -86,6 +116,7 @@ def __init__( max_workers: int = 1, console: t.Optional[Console] = None, notification_target_manager: t.Optional[NotificationTargetManager] = None, + signal_listener: t.Optional[SignalListener] = None, ): self.state_sync = state_sync self.snapshots = {s.snapshot_id: s for s in snapshots} @@ -98,6 +129,7 @@ def __init__( self.notification_target_manager = ( notification_target_manager or NotificationTargetManager() ) + self.signal_listener = signal_listener or SignalListener() def merged_missing_intervals( self, @@ -305,6 +337,7 @@ def batch_intervals( ) for snapshot, intervals in merged_intervals.items() } + self.signal_listener.on_signal_register(list(merged_intervals.keys())) snapshot_batches = {} all_unready_intervals: t.Dict[str, set[Interval]] = {} for snapshot_id in dag: @@ -778,6 +811,15 @@ def _check_ready_intervals( signal_start_ts = time.perf_counter() + self.signal_listener.on_signal_start( + snapshot=snapshot, + signal_name=signal_name, + signal_index=signal_idx, + intervals=intervals_to_check, + signal_kwargs=kwargs, + ) + + error = None try: intervals = check_ready_intervals( signals.prepared_python_env[signal_name], @@ -789,9 +831,21 @@ def _check_ready_intervals( kwargs=kwargs, ) except SQLMeshError as e: - raise SignalEvalError( + error = SignalEvalError( f"{e} '{signal_name}' for '{snapshot.model.name}' at {snapshot.model._path}" ) + raise error + finally: + ready_intervals = merge_intervals(intervals) + self.signal_listener.on_signal_end( + snapshot=snapshot, + signal_name=signal_name, + signal_index=signal_idx, + signal_kwargs=kwargs, + intervals=intervals_to_check, + ready_intervals=ready_intervals, + error=error, + ) duration = time.perf_counter() - signal_start_ts @@ -800,7 +854,7 @@ def _check_ready_intervals( signal_name=signal_name, signal_idx=signal_idx, total_signals=len(signals.signals_to_kwargs), - ready_intervals=merge_intervals(intervals), + ready_intervals=ready_intervals, check_intervals=intervals_to_check, duration=duration, )