37
37
SnapshotCreationFailedError ,
38
38
SnapshotNameVersion ,
39
39
)
40
+ from sqlmesh .core .snapshot .definition import SnapshotEvaluationTriggers
40
41
from sqlmesh .utils import to_snake_case
41
42
from sqlmesh .core .state_sync import StateSync
42
43
from sqlmesh .utils import CorrelationId
@@ -83,6 +84,7 @@ def __init__(
83
84
self .default_catalog = default_catalog
84
85
self .console = console or get_console ()
85
86
self ._circuit_breaker : t .Optional [t .Callable [[], bool ]] = None
87
+ self ._restatement_triggers : t .Dict [SnapshotId , t .List [SnapshotId ]] = {}
86
88
87
89
def evaluate (
88
90
self ,
@@ -234,6 +236,27 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
234
236
self .console .log_success ("SKIP: No model batches to execute" )
235
237
return
236
238
239
+ directly_modified_triggers : t .Dict [SnapshotId , t .List [SnapshotId ]] = {}
240
+ for parent , children in plan .indirectly_modified_snapshots .items ():
241
+ parent_id = stage .all_snapshots [parent ].snapshot_id
242
+ directly_modified_triggers [parent_id ] = directly_modified_triggers .get (
243
+ parent_id , []
244
+ ) + [parent_id ]
245
+ for child in children :
246
+ directly_modified_triggers [child ] = directly_modified_triggers .get (child , []) + [
247
+ parent_id
248
+ ]
249
+ directly_modified_triggers = {
250
+ k : list (dict .fromkeys (v )) for k , v in directly_modified_triggers .items ()
251
+ }
252
+ snapshot_evaluation_triggers = {
253
+ s_id : SnapshotEvaluationTriggers (
254
+ directly_modified_triggers = directly_modified_triggers .get (s_id , []),
255
+ restatement_triggers = self ._restatement_triggers .get (s_id , []),
256
+ )
257
+ for s_id in [s .snapshot_id for s in stage .all_snapshots .values ()]
258
+ }
259
+
237
260
scheduler = self .create_scheduler (stage .all_snapshots .values (), self .snapshot_evaluator )
238
261
# Convert model name restatements to snapshot ID restatements
239
262
restatements_by_snapshot_id = {
@@ -249,6 +272,7 @@ def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePla
249
272
start = plan .start ,
250
273
end = plan .end ,
251
274
restatements = restatements_by_snapshot_id ,
275
+ snapshot_evaluation_triggers = snapshot_evaluation_triggers ,
252
276
)
253
277
if errors :
254
278
raise PlanError ("Plan application failed." )
@@ -286,13 +310,14 @@ def visit_restatement_stage(
286
310
# by forcing dev environments to re-run intervals that changed in prod
287
311
#
288
312
# Without this rule, its possible that promoting a dev table to prod will introduce old data to prod
289
- snapshot_intervals_to_restate . update (
313
+ restatement_intervals_all_environments , self . _restatement_triggers = (
290
314
self ._restatement_intervals_across_all_environments (
291
315
prod_restatements = plan .restatements ,
292
316
disable_restatement_models = plan .disabled_restatement_models ,
293
317
loaded_snapshots = {s .snapshot_id : s for s in stage .all_snapshots .values ()},
294
318
)
295
319
)
320
+ snapshot_intervals_to_restate .update (restatement_intervals_all_environments )
296
321
297
322
self .state_sync .remove_intervals (
298
323
snapshot_intervals = list (snapshot_intervals_to_restate ),
@@ -415,7 +440,9 @@ def _restatement_intervals_across_all_environments(
415
440
prod_restatements : t .Dict [str , Interval ],
416
441
disable_restatement_models : t .Set [str ],
417
442
loaded_snapshots : t .Dict [SnapshotId , Snapshot ],
418
- ) -> t .Set [t .Tuple [SnapshotTableInfo , Interval ]]:
443
+ ) -> t .Tuple [
444
+ t .Set [t .Tuple [SnapshotTableInfo , Interval ]], t .Dict [SnapshotId , t .List [SnapshotId ]]
445
+ ]:
419
446
"""
420
447
Given a map of snapshot names + intervals to restate in prod:
421
448
- Look up matching snapshots across all environments (match based on name - regardless of version)
@@ -426,14 +453,14 @@ def _restatement_intervals_across_all_environments(
426
453
run in those environments causes the intervals to be repopulated
427
454
"""
428
455
if not prod_restatements :
429
- return set ()
456
+ return set (), {}
430
457
431
458
prod_name_versions : t .Set [SnapshotNameVersion ] = {
432
459
s .name_version for s in loaded_snapshots .values ()
433
460
}
434
461
435
462
snapshots_to_restate : t .Dict [SnapshotId , t .Tuple [SnapshotTableInfo , Interval ]] = {}
436
-
463
+ restatement_downstream_ids : t . Dict [ SnapshotId , t . List [ SnapshotId ]] = {}
437
464
for env_summary in self .state_sync .get_environments_summary ():
438
465
# Fetch the full environment object one at a time to avoid loading all environments into memory at once
439
466
env = self .state_sync .get_environment (env_summary .name )
@@ -450,10 +477,17 @@ def _restatement_intervals_across_all_environments(
450
477
for restatement , intervals in prod_restatements .items ():
451
478
if restatement not in keyed_snapshots :
452
479
continue
480
+
481
+ downstream = env_dag .downstream (restatement )
482
+ if not env .is_dev and restatement not in disable_restatement_models :
483
+ restatement_downstream_ids [keyed_snapshots [restatement ].snapshot_id ] = [
484
+ keyed_snapshots [name ].snapshot_id
485
+ for name in downstream
486
+ if name not in disable_restatement_models
487
+ ]
488
+
453
489
affected_snapshot_names = [
454
- x
455
- for x in ([restatement ] + env_dag .downstream (restatement ))
456
- if x not in disable_restatement_models
490
+ x for x in ([restatement ] + downstream ) if x not in disable_restatement_models
457
491
]
458
492
snapshots_to_restate .update (
459
493
{
@@ -464,6 +498,14 @@ def _restatement_intervals_across_all_environments(
464
498
}
465
499
)
466
500
501
+ restatement_triggers : t .Dict [SnapshotId , t .List [SnapshotId ]] = {
502
+ id : [id ] for id in restatement_downstream_ids
503
+ }
504
+ for parent , children in restatement_downstream_ids .items ():
505
+ for child in children :
506
+ restatement_triggers [child ] = restatement_triggers .get (child , []) + [parent ]
507
+ restatement_triggers = {k : list (dict .fromkeys (v )) for k , v in restatement_triggers .items ()}
508
+
467
509
# for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to
468
510
# include the whole time range for that snapshot. This requires a call to state to load the full snapshot record,
469
511
# so we only do it if necessary
@@ -499,7 +541,7 @@ def _restatement_intervals_across_all_environments(
499
541
)
500
542
snapshots_to_restate [full_snapshot_id ] = (full_snapshot .table_info , new_intervals )
501
543
502
- return set (snapshots_to_restate .values ())
544
+ return set (snapshots_to_restate .values ()), restatement_triggers
503
545
504
546
def _update_intervals_for_new_snapshots (self , snapshots : t .Collection [Snapshot ]) -> None :
505
547
snapshots_intervals : t .List [SnapshotIntervals ] = []
0 commit comments