Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions src/sqlite3_to_mysql/transporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from .mysql_utils import (
MYSQL_BLOB_COLUMN_TYPES,
MYSQL_COLUMN_TYPES,
MYSQL_COLUMN_TYPES_WITHOUT_DEFAULT,
MYSQL_INSERT_METHOD,
MYSQL_TEXT_COLUMN_TYPES,
MYSQL_TEXT_COLUMN_TYPES_WITH_JSON,
Expand Down Expand Up @@ -109,6 +108,8 @@ def __init__(self, **kwargs: Unpack[SQLite3toMySQLParams]):

self._mysql_port = kwargs.get("mysql_port", 3306) or 3306

self._is_mariadb = False

if kwargs.get("mysql_socket") is not None:
if not os.path.exists(str(kwargs.get("mysql_socket"))):
raise FileNotFoundError("MySQL socket does not exist")
Expand Down Expand Up @@ -231,6 +232,7 @@ def __init__(self, **kwargs: Unpack[SQLite3toMySQLParams]):
raise

self._mysql_version = self._get_mysql_version()
self._is_mariadb = "-mariadb" in self._mysql_version.lower()
self._mysql_json_support = check_mysql_json_support(self._mysql_version)
self._mysql_fulltext_support = check_mysql_fulltext_support(self._mysql_version)
self._allow_expr_defaults = check_mysql_expression_defaults_support(self._mysql_version)
Expand Down Expand Up @@ -329,6 +331,69 @@ def _create_database(self) -> None:
def _valid_column_type(cls, column_type: str) -> t.Optional[t.Match[str]]:
return cls.COLUMN_PATTERN.match(column_type.strip())

@classmethod
def _base_mysql_column_type(cls, column_type: str) -> str:
stripped: str = column_type.strip()
if not stripped:
return ""
match = cls._valid_column_type(stripped)
if match:
return match.group(0).strip().upper()
return stripped.split("(", 1)[0].strip().upper()

def _column_type_supports_default(self, base_type: str, allow_expr_defaults: bool) -> bool:
normalized: str = base_type.upper()
if not normalized:
return True
if normalized == "GEOMETRY":
return False
if normalized in MYSQL_BLOB_COLUMN_TYPES:
return False
if normalized in MYSQL_TEXT_COLUMN_TYPES_WITH_JSON:
return allow_expr_defaults
return True

@staticmethod
def _parse_sql_expression(value: str) -> t.Optional[exp.Expression]:
stripped: str = value.strip()
if not stripped:
return None
for dialect in ("mysql", "sqlite"):
try:
return sqlglot.parse_one(stripped, read=dialect)
except sqlglot_errors.ParseError:
continue
return None

def _format_textual_default(
self,
default_sql: str,
allow_expr_defaults: bool,
is_mariadb: bool,
) -> str:
"""Normalise textual DEFAULT expressions and wrap for MySQL via sqlglot."""
stripped: str = default_sql.strip()
if not stripped or stripped.upper() == "NULL":
return stripped
if not allow_expr_defaults:
return stripped

expr: t.Optional[exp.Expression] = self._parse_sql_expression(stripped)
if expr is None:
if is_mariadb or stripped.startswith("("):
return stripped
return f"({stripped})"

formatted: str = expr.sql(dialect="mysql")
if is_mariadb:
return formatted

if isinstance(expr, exp.Paren):
return formatted

wrapped = exp.Paren(this=expr.copy())
return wrapped.sql(dialect="mysql")

def _translate_type_from_sqlite_to_mysql(self, column_type: str) -> str:
normalized: t.Optional[str] = self._normalize_sqlite_column_type(column_type)
if normalized and normalized.upper() != column_type.upper():
Expand Down Expand Up @@ -804,16 +869,25 @@ def _create_table(self, table_name: str, transfer_rowid: bool = False, skip_defa
column["pk"] > 0 and column_type.startswith(("INT", "BIGINT")) and not compound_primary_key
)

allow_expr_defaults: bool = getattr(self, "_allow_expr_defaults", False)
is_mariadb: bool = getattr(self, "_is_mariadb", False)
base_type: str = self._base_mysql_column_type(column_type)

# Build DEFAULT clause safely (preserve falsy defaults like 0/'')
default_clause: str = ""
if (
not skip_default
and column["dflt_value"] is not None
and column_type not in MYSQL_COLUMN_TYPES_WITHOUT_DEFAULT
and self._column_type_supports_default(base_type, allow_expr_defaults)
and not auto_increment
):
td: str = self._translate_default_for_mysql(column_type, str(column["dflt_value"]))
if td != "":
stripped_td: str = td.strip()
if base_type in MYSQL_TEXT_COLUMN_TYPES_WITH_JSON and stripped_td.upper() != "NULL":
td = self._format_textual_default(stripped_td, allow_expr_defaults, is_mariadb)
else:
td = stripped_td
default_clause = "DEFAULT " + td
sql += " `{name}` {type} {notnull} {default} {auto_increment}, ".format(
name=mysql_safe_name,
Expand Down
1 change: 1 addition & 0 deletions src/sqlite3_to_mysql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class SQLite3toMySQLAttributes:
_mysql: MySQLConnection
_mysql_cur: MySQLCursor
_mysql_version: str
_is_mariadb: bool
_mysql_json_support: bool
_mysql_fulltext_support: bool
_allow_expr_defaults: bool
Expand Down
195 changes: 195 additions & 0 deletions tests/unit/sqlite3_to_mysql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,201 @@ def test_create_table_invalid_default_retries_without_defaults(self, mocker: Moc
assert "DEFAULT CURRENT_TIMESTAMP" not in retry_sql
instance._logger.warning.assert_called_once()

def test_create_table_text_default_mariadb(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
instance._sqlite_table_xinfo_support = False
instance._sqlite_quote_ident = lambda name: name.replace('"', '""')
instance._mysql_charset = "utf8mb4"
instance._mysql_collation = "utf8mb4_unicode_ci"
instance._logger = mocker.MagicMock()
instance._allow_expr_defaults = True
instance._is_mariadb = True

rows = [
{"name": "body", "type": "TEXT", "notnull": 1, "dflt_value": "'[]'", "pk": 0},
]

sqlite_cursor = mocker.MagicMock()
sqlite_cursor.fetchall.return_value = rows
instance._sqlite_cur = sqlite_cursor

instance._translate_type_from_sqlite_to_mysql = mocker.MagicMock(return_value="TEXT")

mysql_cursor = mocker.MagicMock()
instance._mysql_cur = mysql_cursor
instance._mysql = mocker.MagicMock()

instance._create_table("demo")

executed_sql = mysql_cursor.execute.call_args[0][0]
assert "DEFAULT '[]'" in executed_sql
assert "DEFAULT ('[]')" not in executed_sql

def test_create_table_text_default_mysql_expression(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
instance._sqlite_table_xinfo_support = False
instance._sqlite_quote_ident = lambda name: name.replace('"', '""')
instance._mysql_charset = "utf8mb4"
instance._mysql_collation = "utf8mb4_unicode_ci"
instance._logger = mocker.MagicMock()
instance._allow_expr_defaults = True
instance._is_mariadb = False

rows = [
{"name": "body", "type": "TEXT", "notnull": 1, "dflt_value": "'[]'", "pk": 0},
]

sqlite_cursor = mocker.MagicMock()
sqlite_cursor.fetchall.return_value = rows
instance._sqlite_cur = sqlite_cursor

instance._translate_type_from_sqlite_to_mysql = mocker.MagicMock(return_value="TEXT")

mysql_cursor = mocker.MagicMock()
instance._mysql_cur = mysql_cursor
instance._mysql = mocker.MagicMock()

instance._create_table("demo")

executed_sql = mysql_cursor.execute.call_args[0][0]
assert "DEFAULT ('[]')" in executed_sql

def test_create_table_text_default_mysql_function_expression(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
instance._sqlite_table_xinfo_support = False
instance._sqlite_quote_ident = lambda name: name.replace('"', '""')
instance._mysql_charset = "utf8mb4"
instance._mysql_collation = "utf8mb4_unicode_ci"
instance._logger = mocker.MagicMock()
instance._allow_expr_defaults = True
instance._is_mariadb = False

rows = [
{"name": "body", "type": "TEXT", "notnull": 1, "dflt_value": "json_array()", "pk": 0},
]

sqlite_cursor = mocker.MagicMock()
sqlite_cursor.fetchall.return_value = rows
instance._sqlite_cur = sqlite_cursor

instance._translate_type_from_sqlite_to_mysql = mocker.MagicMock(return_value="TEXT")
instance._translate_default_for_mysql = mocker.MagicMock(return_value="JSON_ARRAY()")

mysql_cursor = mocker.MagicMock()
instance._mysql_cur = mysql_cursor
instance._mysql = mocker.MagicMock()

instance._create_table("demo")

executed_sql = mysql_cursor.execute.call_args[0][0]
assert "DEFAULT (JSON_ARRAY())" in executed_sql

def test_parse_sql_expression_falls_back_to_sqlite(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
sqlite_expr = exp.Literal.string("ok")
parse_mock = mocker.patch(
"sqlite3_to_mysql.transporter.sqlglot.parse_one",
side_effect=[sqlglot_errors.ParseError("mysql"), sqlite_expr],
)

result = instance._parse_sql_expression("value")

assert result is sqlite_expr
assert parse_mock.call_args_list[0].kwargs["read"] == "mysql"
assert parse_mock.call_args_list[1].kwargs["read"] == "sqlite"

def test_parse_sql_expression_returns_none_when_unparseable(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
parse_mock = mocker.patch(
"sqlite3_to_mysql.transporter.sqlglot.parse_one",
side_effect=[
sqlglot_errors.ParseError("mysql"),
sqlglot_errors.ParseError("sqlite"),
],
)

result = instance._parse_sql_expression("value")

assert result is None
assert parse_mock.call_count == 2

def test_format_textual_default_wraps_when_unparseable_mysql(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
mocker.patch.object(instance, "_parse_sql_expression", return_value=None)

result = instance._format_textual_default("raw_json()", True, False)

assert result == "(raw_json())"

def test_format_textual_default_mariadb_uses_literal_output(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
literal_expr = exp.Literal.string("[]")
mocker.patch.object(instance, "_parse_sql_expression", return_value=literal_expr)

result = instance._format_textual_default("'[]'", True, True)

assert result == "'[]'"

def test_format_textual_default_preserves_existing_parens(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
paren_expr = exp.Paren(this=exp.Literal.string("[]"))
mocker.patch.object(instance, "_parse_sql_expression", return_value=paren_expr)

result = instance._format_textual_default("('[]')", True, False)

assert result == "('[]')"

def test_format_textual_default_respects_disabled_expression_defaults(self) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)

result = instance._format_textual_default("'[]'", False, False)

assert result == "'[]'"

def test_base_mysql_column_type_handles_whitespace_and_unknown(self) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)

assert instance._base_mysql_column_type(" TEXT(255) ") == "TEXT"
assert instance._base_mysql_column_type("custom_type") == "CUSTOM_TYPE"
assert instance._base_mysql_column_type("(TEXT)") == ""
assert instance._base_mysql_column_type(" ") == ""

def test_column_type_supports_default_branches(self) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)

assert not instance._column_type_supports_default("GEOMETRY", True)
assert not instance._column_type_supports_default("BLOB", True)
assert not instance._column_type_supports_default("TEXT", False)
assert instance._column_type_supports_default("", True)
assert instance._column_type_supports_default("VARCHAR", False)

def test_parse_sql_expression_returns_none_for_blank(self) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)

assert instance._parse_sql_expression(" ") is None

def test_format_textual_default_handles_blank_and_null(self) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)

assert instance._format_textual_default(" ", True, False) == ""
assert instance._format_textual_default("NULL", True, False) == "NULL"

def test_format_textual_default_mariadb_preserves_unparseable(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
mocker.patch.object(instance, "_parse_sql_expression", return_value=None)

result = instance._format_textual_default("json_array()", True, True)

assert result == "json_array()"

def test_format_textual_default_preserves_parenthesised_unparseable(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
mocker.patch.object(instance, "_parse_sql_expression", return_value=None)

result = instance._format_textual_default("(select 1)", True, False)

assert result == "(select 1)"

def test_truncate_table_executes_when_table_exists(self, mocker: MockerFixture) -> None:
instance = SQLite3toMySQL.__new__(SQLite3toMySQL)
cursor = mocker.MagicMock()
Expand Down