|
2 | 2 | import logging
|
3 | 3 | import re
|
4 | 4 | import typing as t
|
| 5 | +from types import ModuleType |
5 | 6 | from unittest.mock import call, patch
|
6 | 7 |
|
7 | 8 | import duckdb # noqa: TID253
|
|
11 | 12 | from pytest_mock.plugin import MockerFixture
|
12 | 13 | from sqlglot import exp
|
13 | 14 |
|
| 15 | +from sqlmesh.cli.project_init import init_example_project |
14 | 16 | from sqlmesh.core import constants as c
|
15 |
| -from sqlmesh.core.config import EnvironmentSuffixTarget |
| 17 | +from sqlmesh.core.config import ( |
| 18 | + Config, |
| 19 | + DuckDBConnectionConfig, |
| 20 | + EnvironmentSuffixTarget, |
| 21 | + GatewayConfig, |
| 22 | + ModelDefaultsConfig, |
| 23 | +) |
| 24 | +from sqlmesh.core.context import Context |
16 | 25 | from sqlmesh.core.dialect import parse_one, schema_
|
17 | 26 | from sqlmesh.core.engine_adapter import create_engine_adapter
|
18 | 27 | from sqlmesh.core.environment import Environment, EnvironmentStatements
|
|
48 | 57 | )
|
49 | 58 | from sqlmesh.utils.date import now_timestamp, to_datetime, to_timestamp
|
50 | 59 | from sqlmesh.utils.errors import SQLMeshError
|
| 60 | +from tests.utils.test_helpers import use_terminal_console |
51 | 61 |
|
52 | 62 | pytestmark = pytest.mark.slow
|
53 | 63 |
|
@@ -3628,3 +3638,81 @@ def test_update_environment_statements(state_sync: EngineAdapterStateSync):
|
3628 | 3638 | "@grant_schema_usage()",
|
3629 | 3639 | "@grant_select_privileges()",
|
3630 | 3640 | ]
|
| 3641 | + |
| 3642 | + |
| 3643 | +@use_terminal_console |
| 3644 | +def test_pre_checks(tmp_path, mocker): |
| 3645 | + init_example_project(tmp_path, engine_type="duckdb") |
| 3646 | + |
| 3647 | + db_path = str(tmp_path / "db.db") |
| 3648 | + config = Config( |
| 3649 | + gateways={"main": GatewayConfig(connection=DuckDBConnectionConfig(database=db_path))}, |
| 3650 | + model_defaults=ModelDefaultsConfig(dialect="duckdb"), |
| 3651 | + ) |
| 3652 | + context = Context(paths=tmp_path, config=config) |
| 3653 | + context.plan(auto_apply=True, no_prompts=True) |
| 3654 | + |
| 3655 | + def mock_migrate(state_sync, **kwargs): |
| 3656 | + pass |
| 3657 | + |
| 3658 | + def mock_pre_check_with_warnings(state_sync): |
| 3659 | + return [ |
| 3660 | + "Warning: This migration will break compatibility with older versions", |
| 3661 | + "Warning: You must update all model configurations before applying this migration", |
| 3662 | + "Warning: Existing snapshots will need to be rebuilt", |
| 3663 | + ] |
| 3664 | + |
| 3665 | + def mock_pre_check_without_warnings(state_sync): |
| 3666 | + return [] |
| 3667 | + |
| 3668 | + # Create a mock migration module with a pre_check function |
| 3669 | + mock_migration = ModuleType("v9999_test_pre_check") |
| 3670 | + |
| 3671 | + setattr(mock_migration, "migrate", mock_migrate) |
| 3672 | + setattr(mock_migration, "pre_check", mock_pre_check_with_warnings) |
| 3673 | + |
| 3674 | + versions_before_migrate = context.state_sync.get_versions() |
| 3675 | + |
| 3676 | + import sqlmesh.core.state_sync as state_sync |
| 3677 | + |
| 3678 | + test_migrations = state_sync.db.migrator.MIGRATIONS + [mock_migration] |
| 3679 | + |
| 3680 | + # Test 1: Pre-check warnings are properly collected and displayed, user rejects migration |
| 3681 | + with ( |
| 3682 | + patch.object(state_sync.db.migrator, "MIGRATIONS", test_migrations), |
| 3683 | + patch.object(context.console, "_confirm", return_value=False), |
| 3684 | + ): |
| 3685 | + console = context.console |
| 3686 | + log_pre_check_warnings_spy = mocker.spy(console, "log_pre_check_warnings") |
| 3687 | + |
| 3688 | + context.migrate(pre_check_only=False) |
| 3689 | + |
| 3690 | + calls = log_pre_check_warnings_spy.mock_calls |
| 3691 | + assert len(calls) == 1 |
| 3692 | + |
| 3693 | + pre_check_warnings = calls[0].args[0] |
| 3694 | + assert len(pre_check_warnings) == 1 |
| 3695 | + |
| 3696 | + assert pre_check_warnings[0][0] == "v9999_test_pre_check" |
| 3697 | + assert len(pre_check_warnings[0][1]) == 3 |
| 3698 | + assert all(warning.startswith("Warning:") for warning in pre_check_warnings[0][1]) |
| 3699 | + |
| 3700 | + assert context.state_sync.get_versions() == versions_before_migrate |
| 3701 | + |
| 3702 | + update_versions_spy = mocker.spy(state_sync.db.version.VersionState, "update_versions") |
| 3703 | + |
| 3704 | + # Test 2: User accepts migration after being notified about pre-check warnings |
| 3705 | + with ( |
| 3706 | + patch.object(state_sync.db.migrator, "MIGRATIONS", test_migrations), |
| 3707 | + patch.object(context.console, "_confirm", return_value=True), |
| 3708 | + ): |
| 3709 | + context.migrate(pre_check_only=False) |
| 3710 | + assert len(update_versions_spy.mock_calls) == 1 |
| 3711 | + |
| 3712 | + # Test 3: Pre-check without warning should automatically reuslt in a migration |
| 3713 | + setattr(mock_migration, "pre_check", mock_pre_check_without_warnings) |
| 3714 | + with patch.object(state_sync.db.migrator, "MIGRATIONS", test_migrations): |
| 3715 | + # Since the version module's SCHEMA_VERSION, etc, weren't patched, the old versions |
| 3716 | + # are still used, so the following should result in hitting the update_versions path |
| 3717 | + context.migrate(pre_check_only=False) |
| 3718 | + assert len(update_versions_spy.mock_calls) == 2 |
0 commit comments