Skip to content

Commit 78d9ac5

Browse files
committed
Feat: introduce migration pre-checks
1 parent e430f50 commit 78d9ac5

File tree

8 files changed

+125
-11
lines changed

8 files changed

+125
-11
lines changed

sqlmesh/cli/main.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -910,12 +910,17 @@ def ui(ctx: click.Context, host: str, port: int, mode: str) -> None:
910910

911911

912912
@cli.command("migrate")
913+
@click.option(
914+
"--pre-check",
915+
is_flag=True,
916+
help="Run pre-checks and display warnings without performing migration",
917+
)
913918
@click.pass_context
914919
@error_handler
915920
@cli_analytics
916-
def migrate(ctx: click.Context) -> None:
921+
def migrate(ctx: click.Context, pre_check: bool) -> None:
917922
"""Migrate SQLMesh to the current running version."""
918-
ctx.obj.migrate()
923+
ctx.obj.migrate(pre_check_only=pre_check)
919924

920925

921926
@cli.command("rollback")

sqlmesh/core/console.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,17 @@ def update_env_migration_progress(self, num_tasks: int) -> None:
497497
def stop_env_migration_progress(self, success: bool = True) -> None:
498498
"""Stop the environment migration progress."""
499499

500+
@abc.abstractmethod
501+
def log_pre_check_warnings(
502+
self,
503+
pre_check_warnings: t.List[t.Tuple[str, t.List[str]]],
504+
pre_check_only: bool,
505+
) -> bool:
506+
"""
507+
Log warnings emitted by pre-check scripts and ask user whether they'd like to
508+
proceed with the migration (true) or not (false).
509+
"""
510+
500511
@abc.abstractmethod
501512
def plan(
502513
self,
@@ -662,6 +673,13 @@ def update_env_migration_progress(self, num_tasks: int) -> None:
662673
def stop_env_migration_progress(self, success: bool = True) -> None:
663674
pass
664675

676+
def log_pre_check_warnings(
677+
self,
678+
pre_check_warnings: t.List[t.Tuple[str, t.List[str]]],
679+
pre_check_only: bool,
680+
) -> bool:
681+
return True
682+
665683
def start_state_export(
666684
self,
667685
output_file: Path,
@@ -1472,6 +1490,33 @@ def stop_env_migration_progress(self, success: bool = True) -> None:
14721490
if success:
14731491
self.log_success("Environments migrated successfully")
14741492

1493+
def log_pre_check_warnings(
1494+
self,
1495+
pre_check_warnings: t.List[t.Tuple[str, t.List[str]]],
1496+
pre_check_only: bool,
1497+
) -> bool:
1498+
if pre_check_warnings:
1499+
for pre_check, warnings in pre_check_warnings:
1500+
tree = Tree(f"[bold]Pre-migration warnings for {pre_check}[/bold]")
1501+
for warning in warnings:
1502+
tree.add(f"[yellow]{warning}[/yellow]")
1503+
1504+
self._print(tree)
1505+
1506+
if pre_check_only:
1507+
return False
1508+
1509+
should_continue = self._confirm("\nDo you want to proceed with the migration?")
1510+
if not should_continue:
1511+
self.log_status_update("Migration cancelled.")
1512+
1513+
return should_continue
1514+
if pre_check_only:
1515+
self.log_status_update("No pre-migration warnings detected.")
1516+
return False
1517+
1518+
return True
1519+
14751520
def start_state_export(
14761521
self,
14771522
output_file: Path,

sqlmesh/core/context.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2319,24 +2319,33 @@ def check_intervals(
23192319
return results
23202320

23212321
@python_api_analytics
2322-
def migrate(self) -> None:
2322+
def migrate(self, pre_check_only: bool = False) -> None:
23232323
"""Migrates SQLMesh to the current running version.
23242324
23252325
Please contact your SQLMesh administrator before doing this.
2326+
2327+
Args:
2328+
pre_check_only: If True, only run pre-checks without performing the migration.
23262329
"""
2327-
self.notification_target_manager.notify(NotificationEvent.MIGRATION_START)
2330+
if not pre_check_only:
2331+
self.notification_target_manager.notify(NotificationEvent.MIGRATION_START)
2332+
23282333
self._load_materializations()
23292334
try:
23302335
self._new_state_sync().migrate(
23312336
default_catalog=self.default_catalog,
23322337
promoted_snapshots_only=self.config.migration.promoted_snapshots_only,
2338+
pre_check_only=pre_check_only,
23332339
)
23342340
except Exception as e:
2335-
self.notification_target_manager.notify(
2336-
NotificationEvent.MIGRATION_FAILURE, traceback.format_exc()
2337-
)
2341+
if not pre_check_only:
2342+
self.notification_target_manager.notify(
2343+
NotificationEvent.MIGRATION_FAILURE, traceback.format_exc()
2344+
)
23382345
raise e
2339-
self.notification_target_manager.notify(NotificationEvent.MIGRATION_END)
2346+
2347+
if not pre_check_only:
2348+
self.notification_target_manager.notify(NotificationEvent.MIGRATION_END)
23402349

23412350
@python_api_analytics
23422351
def rollback(self) -> None:

sqlmesh/core/state_sync/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from sqlglot import __version__ as SQLGLOT_VERSION
1010

11-
from sqlmesh import migrations
11+
from sqlmesh import migrations, pre_checks
1212
from sqlmesh.core.environment import (
1313
Environment,
1414
EnvironmentNamingInfo,
@@ -64,6 +64,10 @@ def _schema_version_validator(cls, v: t.Any) -> int:
6464
importlib.import_module(f"sqlmesh.migrations.{migration}")
6565
for migration in sorted(info.name for info in pkgutil.iter_modules(migrations.__path__))
6666
]
67+
PRE_CHECKS = {
68+
pre_check: importlib.import_module(f"sqlmesh.pre_checks.{pre_check}")
69+
for pre_check in sorted(info.name for info in pkgutil.iter_modules(pre_checks.__path__))
70+
}
6771
SCHEMA_VERSION: int = len(MIGRATIONS)
6872

6973

@@ -456,6 +460,7 @@ def migrate(
456460
default_catalog: t.Optional[str],
457461
skip_backup: bool = False,
458462
promoted_snapshots_only: bool = True,
463+
pre_check_only: bool = False,
459464
) -> None:
460465
"""Migrate the state sync to the latest SQLMesh / SQLGlot version."""
461466

sqlmesh/core/state_sync/db/facade.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,13 +447,15 @@ def migrate(
447447
default_catalog: t.Optional[str],
448448
skip_backup: bool = False,
449449
promoted_snapshots_only: bool = True,
450+
pre_check_only: bool = False,
450451
) -> None:
451452
"""Migrate the state sync to the latest SQLMesh / SQLGlot version."""
452453
self.migrator.migrate(
453454
self,
454455
default_catalog,
455456
skip_backup=skip_backup,
456457
promoted_snapshots_only=promoted_snapshots_only,
458+
pre_check_only=pre_check_only,
457459
)
458460

459461
@transactional()

sqlmesh/core/state_sync/db/migrator.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from sqlmesh.core.state_sync.base import (
2929
MIGRATIONS,
30+
PRE_CHECKS,
3031
)
3132
from sqlmesh.core.state_sync.base import StateSync
3233
from sqlmesh.core.state_sync.db.environment import EnvironmentState
@@ -90,8 +91,22 @@ def migrate(
9091
default_catalog: t.Optional[str],
9192
skip_backup: bool = False,
9293
promoted_snapshots_only: bool = True,
94+
pre_check_only: bool = False,
9395
) -> None:
94-
"""Migrate the state sync to the latest SQLMesh / SQLGlot version."""
96+
"""Migrate the state sync to the latest SQLMesh / SQLGlot version.
97+
98+
Args:
99+
state_sync: The state sync instance.
100+
default_catalog: The default catalog.
101+
skip_backup: Whether to skip backing up state tables.
102+
promoted_snapshots_only: Whether to migrate only promoted snapshots.
103+
pre_check_only: If True, only run pre-checks without performing migration.
104+
"""
105+
pre_check_warnings = self.run_pre_checks(state_sync)
106+
should_migrate = self.console.log_pre_check_warnings(pre_check_warnings, pre_check_only)
107+
if not should_migrate:
108+
return
109+
95110
versions = self.version_state.get_versions()
96111
migration_start_ts = time.perf_counter()
97112

@@ -153,6 +168,33 @@ def rollback(self) -> None:
153168

154169
logger.info("Migration rollback successful.")
155170

171+
def run_pre_checks(self, state_sync: StateSync) -> t.List[t.Tuple[str, t.List[str]]]:
172+
"""Run pre-checks for migrations between specified versions.
173+
174+
Args:
175+
state_sync: The state sync instance.
176+
177+
Returns:
178+
A list of pairs comprising the executed pre-checks and the corresponding warnings.
179+
"""
180+
# Get the range of the migrations that would be applied
181+
from_version = self.version_state.get_versions().schema_version
182+
to_version = len(MIGRATIONS)
183+
184+
pre_check_warnings = []
185+
for i in range(from_version, to_version):
186+
# Assumption: pre-check and migration names match
187+
pre_check_name = MIGRATIONS[i].__name__.split(".")[-1]
188+
pre_check_module = PRE_CHECKS.get(pre_check_name)
189+
190+
if callable(pre_check := getattr(pre_check_module, "pre_check", None)):
191+
logger.info(f"Running pre-check for {pre_check_name}")
192+
warnings = pre_check(state_sync)
193+
if warnings:
194+
pre_check_warnings.append((pre_check_name, warnings))
195+
196+
return pre_check_warnings
197+
156198
def _apply_migrations(
157199
self,
158200
state_sync: StateSync,

sqlmesh/magics.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,11 +709,17 @@ def dag(self, context: Context, line: str) -> None:
709709
self.display(dag)
710710

711711
@magic_arguments()
712+
@argument(
713+
"--pre-check",
714+
action="store_true",
715+
help="Run pre-checks and display warnings without performing migration",
716+
)
712717
@line_magic
713718
@pass_sqlmesh_context
714719
def migrate(self, context: Context, line: str) -> None:
715720
"""Migrate SQLMesh to the current running version."""
716-
context.migrate()
721+
args = parse_argstring(self.migrate, line)
722+
context.migrate(pre_check_only=args.pre_check)
717723
context.console.log_success("Migration complete")
718724

719725
@magic_arguments()

sqlmesh/pre_checks/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)