From 03dfb100f75f2c01ce8d16e6d8c635a0ee88dcc3 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Thu, 17 Jul 2025 20:54:06 +0300 Subject: [PATCH 1/3] Feat: Add signal listener and methods in scheduler --- sqlmesh/core/scheduler.py | 69 +++++++++++++++++++++++++++++++++++++-- sqlmesh/utils/pydantic.py | 10 ++++++ 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 7177efe927..d9176696d1 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -49,6 +49,7 @@ SQLMeshError, SignalEvalError, ) +from sqlmesh.utils.pydantic import serialize_expressions if t.TYPE_CHECKING: from sqlmesh.core.context import ExecutionContext @@ -60,6 +61,37 @@ SchedulingUnit = t.Tuple[str, t.Tuple[Interval, int]] +class SignalListener: + def on_signal_register( + self, + snapshot_id: SnapshotId, + signals: t.Dict[str, t.Dict[str, str]], + ) -> None: + pass + + def on_signal_start( + self, + snapshot_id: SnapshotId, + signal_name: str, + signal_index: int, + intervals: Intervals, + signal_kwargs: t.Dict[str, str], + ) -> None: + pass + + def on_signal_end( + self, + snapshot_id: SnapshotId, + signal_name: str, + signal_index: int, + signal_kwargs: t.Dict[str, str], + intervals: Intervals, + ready_intervals: Intervals, + error: t.Optional[Exception] = None, + ) -> None: + pass + + class Scheduler: """Schedules and manages the evaluation of snapshots. @@ -86,6 +118,7 @@ def __init__( max_workers: int = 1, console: t.Optional[Console] = None, notification_target_manager: t.Optional[NotificationTargetManager] = None, + signal_listener: t.Optional[SignalListener] = None, ): self.state_sync = state_sync self.snapshots = {s.snapshot_id: s for s in snapshots} @@ -98,6 +131,7 @@ def __init__( self.notification_target_manager = ( notification_target_manager or NotificationTargetManager() ) + self.signal_listener = signal_listener or SignalListener() def merged_missing_intervals( self, @@ -766,6 +800,16 @@ def _check_ready_intervals( if not (signals and signals.signals_to_kwargs): return intervals + signals_to_serialized_kwargs = { + signal_name: serialize_expressions(kwargs) + for signal_name, kwargs in signals.signals_to_kwargs.items() + } + + self.signal_listener.on_signal_register( + snapshot_id=snapshot.snapshot_id, + signals=signals_to_serialized_kwargs, + ) + self.console.start_signal_progress( snapshot, self.default_catalog, @@ -778,6 +822,15 @@ def _check_ready_intervals( signal_start_ts = time.perf_counter() + self.signal_listener.on_signal_start( + snapshot_id=snapshot.snapshot_id, + signal_name=signal_name, + signal_index=signal_idx, + intervals=intervals_to_check, + signal_kwargs=signals_to_serialized_kwargs[signal_name], + ) + + error = None try: intervals = check_ready_intervals( signals.prepared_python_env[signal_name], @@ -789,9 +842,21 @@ def _check_ready_intervals( kwargs=kwargs, ) except SQLMeshError as e: - raise SignalEvalError( + error = SignalEvalError( f"{e} '{signal_name}' for '{snapshot.model.name}' at {snapshot.model._path}" ) + raise error + finally: + ready_intervals = merge_intervals(intervals) + self.signal_listener.on_signal_end( + snapshot_id=snapshot.snapshot_id, + signal_name=signal_name, + signal_index=signal_idx, + signal_kwargs=signals_to_serialized_kwargs[signal_name], + intervals=intervals_to_check, + ready_intervals=ready_intervals, + error=error, + ) duration = time.perf_counter() - signal_start_ts @@ -800,7 +865,7 @@ def _check_ready_intervals( signal_name=signal_name, signal_idx=signal_idx, total_signals=len(signals.signals_to_kwargs), - ready_intervals=merge_intervals(intervals), + ready_intervals=ready_intervals, check_intervals=intervals_to_check, duration=duration, ) diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index 317e873aeb..ec4722ad41 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -60,6 +60,16 @@ def _expression_encoder(e: exp.Expression) -> str: return e.meta.get("sql") or e.sql(dialect=e.meta.get("dialect")) +def serialize_expressions(kwargs: t.Dict[str, t.Optional[exp.Expression]]) -> t.Dict[str, str]: + serialized_kwargs: t.Dict[str, str] = {} + for key, value in kwargs.items(): + if isinstance(value, exp.Expression): + serialized_kwargs[key] = _expression_encoder(value) + else: + serialized_kwargs[key] = str(value) + return serialized_kwargs + + AuditQueryTypes = t.Union[exp.Query, d.JinjaQuery] ModelQueryTypes = t.Union[exp.Query, d.JinjaQuery, d.MacroFunc] From 2115a701a6bfa0a1f6fb567e0f4dea5dc7bcab3d Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Fri, 18 Jul 2025 15:30:57 +0300 Subject: [PATCH 2/3] remove instance check --- sqlmesh/utils/pydantic.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index ec4722ad41..14be48e643 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -63,10 +63,7 @@ def _expression_encoder(e: exp.Expression) -> str: def serialize_expressions(kwargs: t.Dict[str, t.Optional[exp.Expression]]) -> t.Dict[str, str]: serialized_kwargs: t.Dict[str, str] = {} for key, value in kwargs.items(): - if isinstance(value, exp.Expression): - serialized_kwargs[key] = _expression_encoder(value) - else: - serialized_kwargs[key] = str(value) + serialized_kwargs[key] = _expression_encoder(value) if value else str(value) return serialized_kwargs From 7ae9d391467a4ece2c4269cc6b6ed021cee90fb7 Mon Sep 17 00:00:00 2001 From: Themis Valtinos <73662635+themisvaltinos@users.noreply.github.com> Date: Tue, 5 Aug 2025 13:35:00 +0300 Subject: [PATCH 3/3] refactor to pass snapshot and kwargs unserialized --- sqlmesh/core/scheduler.py | 31 ++++++++++--------------------- sqlmesh/utils/pydantic.py | 7 ------- 2 files changed, 10 insertions(+), 28 deletions(-) diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index d9176696d1..f8a171730a 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -49,7 +49,6 @@ SQLMeshError, SignalEvalError, ) -from sqlmesh.utils.pydantic import serialize_expressions if t.TYPE_CHECKING: from sqlmesh.core.context import ExecutionContext @@ -64,27 +63,26 @@ class SignalListener: def on_signal_register( self, - snapshot_id: SnapshotId, - signals: t.Dict[str, t.Dict[str, str]], + snapshots: t.List[Snapshot], ) -> None: pass def on_signal_start( self, - snapshot_id: SnapshotId, + snapshot: Snapshot, signal_name: str, signal_index: int, intervals: Intervals, - signal_kwargs: t.Dict[str, str], + signal_kwargs: t.Dict[str, t.Optional[exp.Expression]], ) -> None: pass def on_signal_end( self, - snapshot_id: SnapshotId, + snapshot: Snapshot, signal_name: str, signal_index: int, - signal_kwargs: t.Dict[str, str], + signal_kwargs: t.Dict[str, t.Optional[exp.Expression]], intervals: Intervals, ready_intervals: Intervals, error: t.Optional[Exception] = None, @@ -339,6 +337,7 @@ def batch_intervals( ) for snapshot, intervals in merged_intervals.items() } + self.signal_listener.on_signal_register(list(merged_intervals.keys())) snapshot_batches = {} all_unready_intervals: t.Dict[str, set[Interval]] = {} for snapshot_id in dag: @@ -800,16 +799,6 @@ def _check_ready_intervals( if not (signals and signals.signals_to_kwargs): return intervals - signals_to_serialized_kwargs = { - signal_name: serialize_expressions(kwargs) - for signal_name, kwargs in signals.signals_to_kwargs.items() - } - - self.signal_listener.on_signal_register( - snapshot_id=snapshot.snapshot_id, - signals=signals_to_serialized_kwargs, - ) - self.console.start_signal_progress( snapshot, self.default_catalog, @@ -823,11 +812,11 @@ def _check_ready_intervals( signal_start_ts = time.perf_counter() self.signal_listener.on_signal_start( - snapshot_id=snapshot.snapshot_id, + snapshot=snapshot, signal_name=signal_name, signal_index=signal_idx, intervals=intervals_to_check, - signal_kwargs=signals_to_serialized_kwargs[signal_name], + signal_kwargs=kwargs, ) error = None @@ -849,10 +838,10 @@ def _check_ready_intervals( finally: ready_intervals = merge_intervals(intervals) self.signal_listener.on_signal_end( - snapshot_id=snapshot.snapshot_id, + snapshot=snapshot, signal_name=signal_name, signal_index=signal_idx, - signal_kwargs=signals_to_serialized_kwargs[signal_name], + signal_kwargs=kwargs, intervals=intervals_to_check, ready_intervals=ready_intervals, error=error, diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index 14be48e643..317e873aeb 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -60,13 +60,6 @@ def _expression_encoder(e: exp.Expression) -> str: return e.meta.get("sql") or e.sql(dialect=e.meta.get("dialect")) -def serialize_expressions(kwargs: t.Dict[str, t.Optional[exp.Expression]]) -> t.Dict[str, str]: - serialized_kwargs: t.Dict[str, str] = {} - for key, value in kwargs.items(): - serialized_kwargs[key] = _expression_encoder(value) if value else str(value) - return serialized_kwargs - - AuditQueryTypes = t.Union[exp.Query, d.JinjaQuery] ModelQueryTypes = t.Union[exp.Query, d.JinjaQuery, d.MacroFunc]