Skip to content

Commit 205bcff

Browse files
Feat: Add signal listener and methods in scheduler
1 parent 9a954d7 commit 205bcff

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

sqlmesh/core/scheduler.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
SQLMeshError,
5050
SignalEvalError,
5151
)
52+
from sqlmesh.utils.pydantic import serialize_expressions
5253

5354
if t.TYPE_CHECKING:
5455
from sqlmesh.core.context import ExecutionContext
@@ -60,6 +61,37 @@
6061
SchedulingUnit = t.Tuple[str, t.Tuple[Interval, int]]
6162

6263

64+
class SignalListener:
65+
def on_signal_register(
66+
self,
67+
snapshot_id: SnapshotId,
68+
signals: t.Dict[str, t.Dict[str, str]],
69+
) -> None:
70+
pass
71+
72+
def on_signal_start(
73+
self,
74+
snapshot_id: SnapshotId,
75+
signal_name: str,
76+
signal_index: int,
77+
intervals: Intervals,
78+
signal_kwargs: t.Dict[str, str],
79+
) -> None:
80+
pass
81+
82+
def on_signal_end(
83+
self,
84+
snapshot_id: SnapshotId,
85+
signal_name: str,
86+
signal_index: int,
87+
signal_kwargs: t.Dict[str, str],
88+
intervals: Intervals,
89+
ready_intervals: Intervals,
90+
error: t.Optional[Exception] = None,
91+
) -> None:
92+
pass
93+
94+
6395
class Scheduler:
6496
"""Schedules and manages the evaluation of snapshots.
6597
@@ -86,6 +118,7 @@ def __init__(
86118
max_workers: int = 1,
87119
console: t.Optional[Console] = None,
88120
notification_target_manager: t.Optional[NotificationTargetManager] = None,
121+
signal_listener: t.Optional[SignalListener] = None,
89122
):
90123
self.state_sync = state_sync
91124
self.snapshots = {s.snapshot_id: s for s in snapshots}
@@ -98,6 +131,7 @@ def __init__(
98131
self.notification_target_manager = (
99132
notification_target_manager or NotificationTargetManager()
100133
)
134+
self.signal_listener = signal_listener or SignalListener()
101135

102136
def merged_missing_intervals(
103137
self,
@@ -766,6 +800,16 @@ def _check_ready_intervals(
766800
if not (signals and signals.signals_to_kwargs):
767801
return intervals
768802

803+
signals_to_serialized_kwargs = {
804+
signal_name: serialize_expressions(kwargs)
805+
for signal_name, kwargs in signals.signals_to_kwargs.items()
806+
}
807+
808+
self.signal_listener.on_signal_register(
809+
snapshot_id=snapshot.snapshot_id,
810+
signals=signals_to_serialized_kwargs,
811+
)
812+
769813
self.console.start_signal_progress(
770814
snapshot,
771815
self.default_catalog,
@@ -778,6 +822,15 @@ def _check_ready_intervals(
778822

779823
signal_start_ts = time.perf_counter()
780824

825+
self.signal_listener.on_signal_start(
826+
snapshot_id=snapshot.snapshot_id,
827+
signal_name=signal_name,
828+
signal_index=signal_idx,
829+
intervals=intervals_to_check,
830+
signal_kwargs=signals_to_serialized_kwargs[signal_name],
831+
)
832+
833+
error = None
781834
try:
782835
intervals = check_ready_intervals(
783836
signals.prepared_python_env[signal_name],
@@ -789,9 +842,21 @@ def _check_ready_intervals(
789842
kwargs=kwargs,
790843
)
791844
except SQLMeshError as e:
792-
raise SignalEvalError(
845+
error = SignalEvalError(
793846
f"{e} '{signal_name}' for '{snapshot.model.name}' at {snapshot.model._path}"
794847
)
848+
raise error
849+
finally:
850+
ready_intervals = merge_intervals(intervals)
851+
self.signal_listener.on_signal_end(
852+
snapshot_id=snapshot.snapshot_id,
853+
signal_name=signal_name,
854+
signal_index=signal_idx,
855+
signal_kwargs=signals_to_serialized_kwargs[signal_name],
856+
intervals=intervals_to_check,
857+
ready_intervals=ready_intervals,
858+
error=error,
859+
)
795860

796861
duration = time.perf_counter() - signal_start_ts
797862

@@ -800,7 +865,7 @@ def _check_ready_intervals(
800865
signal_name=signal_name,
801866
signal_idx=signal_idx,
802867
total_signals=len(signals.signals_to_kwargs),
803-
ready_intervals=merge_intervals(intervals),
868+
ready_intervals=ready_intervals,
804869
check_intervals=intervals_to_check,
805870
duration=duration,
806871
)

sqlmesh/utils/pydantic.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ def _expression_encoder(e: exp.Expression) -> str:
6060
return e.meta.get("sql") or e.sql(dialect=e.meta.get("dialect"))
6161

6262

63+
def serialize_expressions(kwargs: t.Dict[str, t.Optional[exp.Expression]]) -> t.Dict[str, str]:
64+
serialized_kwargs: t.Dict[str, str] = {}
65+
for key, value in kwargs.items():
66+
if isinstance(value, exp.Expression):
67+
serialized_kwargs[key] = _expression_encoder(value)
68+
else:
69+
serialized_kwargs[key] = str(value)
70+
return serialized_kwargs
71+
72+
6373
AuditQueryTypes = t.Union[exp.Query, d.JinjaQuery]
6474
ModelQueryTypes = t.Union[exp.Query, d.JinjaQuery, d.MacroFunc]
6575

0 commit comments

Comments
 (0)