Skip to content

Commit 467b518

Browse files
authored
🦺 improve SQLite-to-MySQL type and default value translation logic (#149)
1 parent 43888d4 commit 467b518

File tree

2 files changed

+591
-1
lines changed

2 files changed

+591
-1
lines changed

src/sqlite3_to_mysql/transporter.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,55 @@ def _valid_column_type(cls, column_type: str) -> t.Optional[t.Match[str]]:
330330
return cls.COLUMN_PATTERN.match(column_type.strip())
331331

332332
def _translate_type_from_sqlite_to_mysql(self, column_type: str) -> str:
333+
normalized: t.Optional[str] = self._normalize_sqlite_column_type(column_type)
334+
if normalized and normalized.upper() != column_type.upper():
335+
self._logger.info("Normalised SQLite column type %r -> %r", column_type, normalized)
336+
try:
337+
return self._translate_type_from_sqlite_to_mysql_legacy(normalized)
338+
except ValueError:
339+
pass
340+
return self._translate_type_from_sqlite_to_mysql_legacy(column_type)
341+
342+
def _normalize_sqlite_column_type(self, column_type: str) -> t.Optional[str]:
343+
clean_type: str = column_type.strip()
344+
if not clean_type:
345+
return None
346+
347+
normalized_for_parse: str = clean_type.upper().replace("UNSIGNED BIG INT", "BIGINT UNSIGNED")
348+
try:
349+
expression = sqlglot.parse_one(f"SELECT CAST(NULL AS {normalized_for_parse})", read="sqlite")
350+
except sqlglot_errors.ParseError:
351+
# Retry: strip UNSIGNED to aid parsing; we'll re-attach it below if present.
352+
try:
353+
no_unsigned = re.sub(r"\bUNSIGNED\b", "", normalized_for_parse).strip()
354+
expression = sqlglot.parse_one(f"SELECT CAST(NULL AS {no_unsigned})", read="sqlite")
355+
except sqlglot_errors.ParseError:
356+
return None
357+
358+
cast: t.Optional[exp.Cast] = expression.find(exp.Cast)
359+
if not cast or not isinstance(cast.to, exp.DataType):
360+
return None
361+
362+
params: t.List[str] = []
363+
for expr_param in cast.to.expressions or []:
364+
value_expr = expr_param.this if isinstance(expr_param, exp.DataTypeParam) else expr_param
365+
if value_expr is None:
366+
continue
367+
params.append(value_expr.sql(dialect="mysql"))
368+
369+
base_match: t.Optional[t.Match[str]] = self._valid_column_type(clean_type)
370+
base = base_match.group(0).upper().strip() if base_match else clean_type.upper()
371+
372+
normalized = base
373+
if params:
374+
normalized += "(" + ",".join(param.strip("\"'") for param in params) + ")"
375+
376+
if "UNSIGNED" in clean_type.upper() and "UNSIGNED" not in normalized.upper().split():
377+
normalized = f"{normalized} UNSIGNED"
378+
379+
return normalized
380+
381+
def _translate_type_from_sqlite_to_mysql_legacy(self, column_type: str) -> str:
333382
"""This could be optimized even further, however is seems adequate."""
334383
full_column_type: str = column_type.upper()
335384
unsigned: bool = self.COLUMN_UNSIGNED_PATTERN.search(full_column_type) is not None
@@ -534,6 +583,19 @@ def _translate_default_for_mysql(self, column_type: str, default: str) -> str:
534583
return s
535584

536585
# Fallback: return stripped expression (MySQL 8.0.13+ allows expression defaults)
586+
if self._allow_expr_defaults:
587+
try:
588+
expr = sqlglot.parse_one(s, read="sqlite")
589+
except sqlglot_errors.ParseError:
590+
return s
591+
592+
expr = expr.transform(self._rewrite_sqlite_view_functions)
593+
594+
try:
595+
return expr.sql(dialect="mysql")
596+
except sqlglot_errors.SqlglotError:
597+
return s
598+
537599
return s
538600

539601
@classmethod

0 commit comments

Comments
 (0)