Skip to content

Feat!: introduce migration pre-checks #5001

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sqlmesh/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,12 +917,17 @@ def ui(ctx: click.Context, host: str, port: int, mode: str) -> None:


@cli.command("migrate")
@click.option(
"--pre-check",
is_flag=True,
help="Run pre-checks and display warnings without performing migration",
)
@click.pass_context
@error_handler
@cli_analytics
def migrate(ctx: click.Context) -> None:
def migrate(ctx: click.Context, pre_check: bool) -> None:
"""Migrate SQLMesh to the current running version."""
ctx.obj.migrate()
ctx.obj.migrate(pre_check_only=pre_check)


@cli.command("rollback")
Expand Down
32 changes: 32 additions & 0 deletions sqlmesh/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,13 @@ def update_env_migration_progress(self, num_tasks: int) -> None:
def stop_env_migration_progress(self, success: bool = True) -> None:
"""Stop the environment migration progress."""

@abc.abstractmethod
def log_pre_check_warnings(self, pre_check_warnings: t.List[str], pre_check_only: bool) -> bool:
"""
Log warnings emitted by pre-checks and ask user whether they'd like to
proceed with the migration (true) or not (false).
"""

@abc.abstractmethod
def plan(
self,
Expand Down Expand Up @@ -662,6 +669,9 @@ def update_env_migration_progress(self, num_tasks: int) -> None:
def stop_env_migration_progress(self, success: bool = True) -> None:
pass

def log_pre_check_warnings(self, pre_check_warnings: t.List[str], pre_check_only: bool) -> bool:
return True

def start_state_export(
self,
output_file: Path,
Expand Down Expand Up @@ -1472,6 +1482,28 @@ def stop_env_migration_progress(self, success: bool = True) -> None:
if success:
self.log_success("Environments migrated successfully")

def log_pre_check_warnings(self, pre_check_warnings: t.List[str], pre_check_only: bool) -> bool:
if pre_check_warnings:
tree = Tree(f"[bold]Pre-migration warnings[/bold]")
for warning in pre_check_warnings:
tree.add(f"[yellow]{warning}[/yellow]")

self._print(tree)

if pre_check_only:
return False

should_continue = self._confirm("\nDo you want to proceed with the migration?")
if not should_continue:
self.log_status_update("Migration cancelled.")

return should_continue
if pre_check_only:
self.log_status_update("No pre-migration warnings detected.")
return False

return True

def start_state_export(
self,
output_file: Path,
Expand Down
21 changes: 15 additions & 6 deletions sqlmesh/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2319,24 +2319,33 @@ def check_intervals(
return results

@python_api_analytics
def migrate(self) -> None:
def migrate(self, pre_check_only: bool = False) -> None:
"""Migrates SQLMesh to the current running version.

Please contact your SQLMesh administrator before doing this.

Args:
pre_check_only: If True, only run pre-checks without performing the migration.
"""
self.notification_target_manager.notify(NotificationEvent.MIGRATION_START)
if not pre_check_only:
self.notification_target_manager.notify(NotificationEvent.MIGRATION_START)

self._load_materializations()
try:
self._new_state_sync().migrate(
default_catalog=self.default_catalog,
promoted_snapshots_only=self.config.migration.promoted_snapshots_only,
pre_check_only=pre_check_only,
)
except Exception as e:
self.notification_target_manager.notify(
NotificationEvent.MIGRATION_FAILURE, traceback.format_exc()
)
if not pre_check_only:
self.notification_target_manager.notify(
NotificationEvent.MIGRATION_FAILURE, traceback.format_exc()
)
raise e
self.notification_target_manager.notify(NotificationEvent.MIGRATION_END)

if not pre_check_only:
self.notification_target_manager.notify(NotificationEvent.MIGRATION_END)

@python_api_analytics
def rollback(self) -> None:
Expand Down
13 changes: 11 additions & 2 deletions sqlmesh/core/state_sync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Versions(PydanticModel):
schema_version: int = 0
sqlglot_version: str = "0.0.0"
sqlmesh_version: str = "0.0.0"
pre_check_version: int = 0

@property
def minor_sqlglot_version(self) -> t.Tuple[int, int]:
Expand All @@ -54,9 +55,9 @@ def minor_sqlmesh_version(self) -> t.Tuple[int, int]:
def _package_version_validator(cls, v: t.Any) -> str:
return "0.0.0" if v is None else str(v)

@field_validator("schema_version", mode="before")
@field_validator("schema_version", "pre_check_version", mode="before")
@classmethod
def _schema_version_validator(cls, v: t.Any) -> int:
def _int_version_validator(cls, v: t.Any) -> int:
return 0 if v is None else int(v)


Expand All @@ -65,6 +66,13 @@ def _schema_version_validator(cls, v: t.Any) -> int:
for migration in sorted(info.name for info in pkgutil.iter_modules(migrations.__path__))
]
SCHEMA_VERSION: int = len(MIGRATIONS)
PRE_CHECK_VERSION: int = (
max(
[idx for idx, migration in enumerate(MIGRATIONS) if hasattr(migration, "pre_check")],
default=-1,
)
+ 1
)


class PromotionResult(PydanticModel):
Expand Down Expand Up @@ -456,6 +464,7 @@ def migrate(
default_catalog: t.Optional[str],
skip_backup: bool = False,
promoted_snapshots_only: bool = True,
pre_check_only: bool = False,
) -> None:
"""Migrate the state sync to the latest SQLMesh / SQLGlot version."""

Expand Down
2 changes: 2 additions & 0 deletions sqlmesh/core/state_sync/db/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,13 +447,15 @@ def migrate(
default_catalog: t.Optional[str],
skip_backup: bool = False,
promoted_snapshots_only: bool = True,
pre_check_only: bool = False,
) -> None:
"""Migrate the state sync to the latest SQLMesh / SQLGlot version."""
self.migrator.migrate(
self,
default_catalog,
skip_backup=skip_backup,
promoted_snapshots_only=promoted_snapshots_only,
pre_check_only=pre_check_only,
)

@transactional()
Expand Down
44 changes: 40 additions & 4 deletions sqlmesh/core/state_sync/db/migrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
from sqlmesh.core.snapshot.definition import (
_parents_from_node,
)
from sqlmesh.core.state_sync.base import (
MIGRATIONS,
)
from sqlmesh.core.state_sync.base import MIGRATIONS
from sqlmesh.core.state_sync.base import StateSync
from sqlmesh.core.state_sync.db.environment import EnvironmentState
from sqlmesh.core.state_sync.db.interval import IntervalState
Expand Down Expand Up @@ -90,8 +88,22 @@ def migrate(
default_catalog: t.Optional[str],
skip_backup: bool = False,
promoted_snapshots_only: bool = True,
pre_check_only: bool = False,
) -> None:
"""Migrate the state sync to the latest SQLMesh / SQLGlot version."""
"""Migrate the state sync to the latest SQLMesh / SQLGlot version.

Args:
state_sync: The state sync instance.
default_catalog: The default catalog.
skip_backup: Whether to skip backing up state tables.
promoted_snapshots_only: Whether to migrate only promoted snapshots.
pre_check_only: If True, only run pre-checks without performing migration.
"""
pre_check_warnings = self._run_pre_checks(state_sync)
should_migrate = self.console.log_pre_check_warnings(pre_check_warnings, pre_check_only)
if not should_migrate:
return

versions = self.version_state.get_versions()
migration_start_ts = time.perf_counter()

Expand Down Expand Up @@ -153,6 +165,30 @@ def rollback(self) -> None:

logger.info("Migration rollback successful.")

def _run_pre_checks(self, state_sync: StateSync) -> t.List[str]:
"""Run pre-checks for migrations between specified versions.

Args:
state_sync: The state sync instance.

Returns:
A list of pairs comprising the migration name containing the executed pre-checks
and the corresponding warnings.
"""
versions = self.version_state.get_versions()
migrations = MIGRATIONS[versions.schema_version :]

pre_check_warnings = []
for migration in migrations:
if callable(pre_check := getattr(migration, "pre_check", None)):
migration_name = migration.__name__.split(".")[-1]
logger.info(f"Running pre-check for {migration_name}")
warnings = pre_check(state_sync)
if warnings:
pre_check_warnings.extend(warnings)

return pre_check_warnings

def _apply_migrations(
self,
state_sync: StateSync,
Expand Down
9 changes: 8 additions & 1 deletion sqlmesh/core/state_sync/db/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
SQLMESH_VERSION,
)
from sqlmesh.core.state_sync.base import (
PRE_CHECK_VERSION,
SCHEMA_VERSION,
Versions,
)
Expand All @@ -31,13 +32,15 @@ def __init__(self, engine_adapter: EngineAdapter, schema: t.Optional[str] = None
"schema_version": exp.DataType.build("int"),
"sqlglot_version": exp.DataType.build(index_type),
"sqlmesh_version": exp.DataType.build(index_type),
"pre_check_version": exp.DataType.build("int"),
}

def update_versions(
self,
schema_version: int = SCHEMA_VERSION,
sqlglot_version: str = SQLGLOT_VERSION,
sqlmesh_version: str = SQLMESH_VERSION,
pre_check_version: int = PRE_CHECK_VERSION,
) -> None:
import pandas as pd

Expand All @@ -51,6 +54,7 @@ def update_versions(
"schema_version": schema_version,
"sqlglot_version": sqlglot_version,
"sqlmesh_version": sqlmesh_version,
"pre_check_version": pre_check_version,
}
]
),
Expand All @@ -69,5 +73,8 @@ def get_versions(self) -> Versions:
return no_version

return Versions(
schema_version=row[0], sqlglot_version=row[1], sqlmesh_version=seq_get(row, 2)
schema_version=row[0],
sqlglot_version=row[1],
sqlmesh_version=seq_get(row, 2),
pre_check_version=seq_get(row, 3),
)
8 changes: 7 additions & 1 deletion sqlmesh/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,11 +709,17 @@ def dag(self, context: Context, line: str) -> None:
self.display(dag)

@magic_arguments()
@argument(
"--pre-check",
action="store_true",
help="Run pre-checks and display warnings without performing migration",
)
@line_magic
@pass_sqlmesh_context
def migrate(self, context: Context, line: str) -> None:
"""Migrate SQLMesh to the current running version."""
context.migrate()
args = parse_argstring(self.migrate, line)
context.migrate(pre_check_only=args.pre_check)
context.console.log_success("Migration complete")

@magic_arguments()
Expand Down
24 changes: 24 additions & 0 deletions sqlmesh/migrations/v0089_add_pre_check_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Add new 'pre_check_version' column to the version state table."""

from sqlglot import exp


def migrate(state_sync, **kwargs): # type: ignore
engine_adapter = state_sync.engine_adapter
schema = state_sync.schema
versions_table = "_versions"
if schema:
versions_table = f"{schema}.{versions_table}"

alter_table_exp = exp.Alter(
this=exp.to_table(versions_table),
kind="TABLE",
actions=[
exp.ColumnDef(
this=exp.to_column("pre_check_version"),
kind=exp.DataType.build("int"),
)
],
)

engine_adapter.execute(alter_table_exp)
Loading