Skip to content

Commit 7ae9d39

Browse files
refactor to pass snapshot and kwargs unserialized
1 parent 2115a70 commit 7ae9d39

File tree

2 files changed

+10
-28
lines changed

2 files changed

+10
-28
lines changed

sqlmesh/core/scheduler.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
SQLMeshError,
5050
SignalEvalError,
5151
)
52-
from sqlmesh.utils.pydantic import serialize_expressions
5352

5453
if t.TYPE_CHECKING:
5554
from sqlmesh.core.context import ExecutionContext
@@ -64,27 +63,26 @@
6463
class SignalListener:
6564
def on_signal_register(
6665
self,
67-
snapshot_id: SnapshotId,
68-
signals: t.Dict[str, t.Dict[str, str]],
66+
snapshots: t.List[Snapshot],
6967
) -> None:
7068
pass
7169

7270
def on_signal_start(
7371
self,
74-
snapshot_id: SnapshotId,
72+
snapshot: Snapshot,
7573
signal_name: str,
7674
signal_index: int,
7775
intervals: Intervals,
78-
signal_kwargs: t.Dict[str, str],
76+
signal_kwargs: t.Dict[str, t.Optional[exp.Expression]],
7977
) -> None:
8078
pass
8179

8280
def on_signal_end(
8381
self,
84-
snapshot_id: SnapshotId,
82+
snapshot: Snapshot,
8583
signal_name: str,
8684
signal_index: int,
87-
signal_kwargs: t.Dict[str, str],
85+
signal_kwargs: t.Dict[str, t.Optional[exp.Expression]],
8886
intervals: Intervals,
8987
ready_intervals: Intervals,
9088
error: t.Optional[Exception] = None,
@@ -339,6 +337,7 @@ def batch_intervals(
339337
)
340338
for snapshot, intervals in merged_intervals.items()
341339
}
340+
self.signal_listener.on_signal_register(list(merged_intervals.keys()))
342341
snapshot_batches = {}
343342
all_unready_intervals: t.Dict[str, set[Interval]] = {}
344343
for snapshot_id in dag:
@@ -800,16 +799,6 @@ def _check_ready_intervals(
800799
if not (signals and signals.signals_to_kwargs):
801800
return intervals
802801

803-
signals_to_serialized_kwargs = {
804-
signal_name: serialize_expressions(kwargs)
805-
for signal_name, kwargs in signals.signals_to_kwargs.items()
806-
}
807-
808-
self.signal_listener.on_signal_register(
809-
snapshot_id=snapshot.snapshot_id,
810-
signals=signals_to_serialized_kwargs,
811-
)
812-
813802
self.console.start_signal_progress(
814803
snapshot,
815804
self.default_catalog,
@@ -823,11 +812,11 @@ def _check_ready_intervals(
823812
signal_start_ts = time.perf_counter()
824813

825814
self.signal_listener.on_signal_start(
826-
snapshot_id=snapshot.snapshot_id,
815+
snapshot=snapshot,
827816
signal_name=signal_name,
828817
signal_index=signal_idx,
829818
intervals=intervals_to_check,
830-
signal_kwargs=signals_to_serialized_kwargs[signal_name],
819+
signal_kwargs=kwargs,
831820
)
832821

833822
error = None
@@ -849,10 +838,10 @@ def _check_ready_intervals(
849838
finally:
850839
ready_intervals = merge_intervals(intervals)
851840
self.signal_listener.on_signal_end(
852-
snapshot_id=snapshot.snapshot_id,
841+
snapshot=snapshot,
853842
signal_name=signal_name,
854843
signal_index=signal_idx,
855-
signal_kwargs=signals_to_serialized_kwargs[signal_name],
844+
signal_kwargs=kwargs,
856845
intervals=intervals_to_check,
857846
ready_intervals=ready_intervals,
858847
error=error,

sqlmesh/utils/pydantic.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,6 @@ def _expression_encoder(e: exp.Expression) -> str:
6060
return e.meta.get("sql") or e.sql(dialect=e.meta.get("dialect"))
6161

6262

63-
def serialize_expressions(kwargs: t.Dict[str, t.Optional[exp.Expression]]) -> t.Dict[str, str]:
64-
serialized_kwargs: t.Dict[str, str] = {}
65-
for key, value in kwargs.items():
66-
serialized_kwargs[key] = _expression_encoder(value) if value else str(value)
67-
return serialized_kwargs
68-
69-
7063
AuditQueryTypes = t.Union[exp.Query, d.JinjaQuery]
7164
ModelQueryTypes = t.Union[exp.Query, d.JinjaQuery, d.MacroFunc]
7265

0 commit comments

Comments
 (0)