diff --git a/requirements.txt b/requirements.txt
index 228b8d4..efbfed8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+sqlalchemy>=2.0.0
 pytest==8.4.1
 pytest-cov==6.2.1
 pdoc==15.0.4
diff --git a/stat_log_db/pyproject.toml b/stat_log_db/pyproject.toml
index efd8321..f23b15c 100644
--- a/stat_log_db/pyproject.toml
+++ b/stat_log_db/pyproject.toml
@@ -9,6 +9,7 @@ description = ""
 readme = "README.md"
 requires-python = ">=3.12.10"
 dependencies = [
+    "sqlalchemy>=2.0.0"
 ]
 
 [project.optional-dependencies]
diff --git a/stat_log_db/src/stat_log_db/cli.py b/stat_log_db/src/stat_log_db/cli.py
index 55189b7..d9a872b 100644
--- a/stat_log_db/src/stat_log_db/cli.py
+++ b/stat_log_db/src/stat_log_db/cli.py
@@ -2,7 +2,7 @@
 # import sys
 # from .parser import create_parser
 
-from .db import MemDB # , FileDB, Database, BaseConnection
+from .db import Database # , MemDB, FileDB, BaseConnection
 
 
 def main():
@@ -18,7 +18,7 @@ def main():
 
     # print(f"{args=}")
 
-    sl_db = MemDB({
+    sl_db = Database({
         "is_mem": True,
         "fkey_constraint": True
     })
diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py
index 6b29ea9..445eb25 100644
--- a/stat_log_db/src/stat_log_db/db.py
+++ b/stat_log_db/src/stat_log_db/db.py
@@ -4,6 +4,10 @@
 import uuid
 from typing import Any
 
+from sqlalchemy import create_engine, text, Engine, Connection as SQLAConnection, MetaData, Table, Column
+from sqlalchemy import Integer, String, Text, Boolean, Float, DateTime, LargeBinary
+from sqlalchemy.engine import make_url
+from sqlalchemy.pool import StaticPool, QueuePool
 from .exceptions import raise_auto_arg_type_error
 
 
@@ -28,8 +32,8 @@ def __init__(self, options: dict[str, Any] = {}):
         self._db_name: str = options.get("db_name", str(uuid.uuid4()))
         self._db_file_name: str = ":memory:" if self._in_memory else self._db_name.replace(" ", "_")
         self._fkey_constraint: bool = options.get("fkey_constraint", True)
-        # Keep track of active connections (to ensure that they are closed)
-        self._connections: dict[str, BaseConnection] = dict()
+        # Create SQLAlchemy Engine with appropriate connection pooling
+        self._engine = self._create_engine()
 
     @property
     def name(self) -> str:
@@ -50,95 +54,56 @@ def is_file(self) -> bool:
     @property
     def fkey_constraint(self) -> bool:
         return self._fkey_constraint
-
-    def check_connection_integrity(self, connection: 'str | BaseConnection', skip_registry_type_check: bool = False):
-        """
-        Check the integrity of a given connection's registration.
-        The connection to be checked can be passed as an UID string or a connection object (instance of BaseConnection).
-        """
-        if not isinstance(skip_registry_type_check, bool):
-            raise_auto_arg_type_error("skip_registry_type_check")
-        connection_is_uid_str = isinstance(connection, str)
-        connection_is_obj = isinstance(connection, BaseConnection)
-        if (not connection_is_uid_str) and (not connection_is_obj):
-            raise_auto_arg_type_error("connection")
-        if self._connections is None or len(self._connections) == 0:
-            raise ValueError(f"Connection {connection.uid if connection_is_obj else connection} is not registered, as Connection Registry contains no connections.")
-        # Check that the registry is of the expected type
-        if (not skip_registry_type_check) and (not isinstance(self._connections, dict)):
-            raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}")
-        # If the passed-in connection is a uid string,
-        # search the registry keys for that uid string.
-        # Check that a matching connection is found,
-        # that it has a valid UID, and that it is registered
-        # under the uid that it has (registry key = found connection's uid).
-        if connection_is_uid_str:
-            if len(connection) == 0:
-                raise ValueError("Connection UID string is empty.")
-            found_connection = self._connections.get(connection, None)
-            if found_connection is None:
-                raise ValueError(f"Connection '{connection}' is not registered.")
-            if not isinstance(found_connection.uid, str):
-                raise TypeError(f"Expected the found connection's uid to be str, got {type(found_connection.uid).__name__} instead.")
-            if len(found_connection.uid) == 0:
-                raise ValueError("Found connection's uid string is empty.")
-            if found_connection.uid != connection:
-                raise ValueError(f"Connection '{connection}' is registered under non-matching uid: {found_connection.uid}")
-        # If the passed-in connection is a BaseConnection object,
-        # check that it has a valid uid and that it's UID is in the registry
-        elif connection_is_obj:
-            if not isinstance(connection.uid, str):
-                raise TypeError(f"Expected the connection's uid to be str, got {type(connection.uid).__name__} instead.")
-            if connection.uid not in self._connections:
-                raise ValueError(f"Connection '{connection.uid}' is not registered, or is registered under the wrong uid.")
-
-    def check_connection_registry_integrity(self, skip_registry_type_check: bool = False):
-        """
-        Check the integrity of the connection registry.
-        If not all connections are registered, no error is raised.
-        """
-        if not isinstance(skip_registry_type_check, bool):
-            raise_auto_arg_type_error("skip_registry_type_check")
-        # Check that the registry is of the expected type
-        if (not skip_registry_type_check) and (not isinstance(self._connections, dict)):
-            raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}")
-        # If there are no connections, nothing to check
-        if len(self._connections) == 0:
-            return
-        # Check that all registered connections are registered under a UID of the correct type and are instances of BaseConnection
-        if any((not isinstance(uid, str)) or (not isinstance(conn, BaseConnection)) for uid, conn in self._connections.items()):
-            raise TypeError("All connections must be registered by their UID string and be instances of BaseConnection.")
-        # Perform individual connection integrity checks
-        for uid in self._connections.keys():
-            self.check_connection_integrity(uid, skip_registry_type_check=True) # Registry type already checked
-
-    def _register_connection(self):
-        """
-        Creates a new database connection object and registers it.
-        Does not open the connection.
-        """
-        connection = BaseConnection(self)
-        self._connections[connection.uid] = connection
-        self.check_connection_integrity(connection)
-        return connection
-
-    def _unregister_connection(self, connection: 'str | BaseConnection'):
+
+    @property
+    def engine(self) -> Engine:
+        """Get the SQLAlchemy Engine for this database."""
+        return self._engine
+
+    def _create_engine(self) -> Engine:
+        """Create SQLAlchemy Engine with appropriate configuration for SQLite."""
+        if self._in_memory:
+            # For in-memory databases, use StaticPool to ensure single connection
+            # and prevent the database from being destroyed when connections close
+            url = "sqlite:///:memory:"
+            engine = create_engine(
+                url,
+                poolclass=StaticPool,
+                pool_pre_ping=True,
+                connect_args={
+                    "check_same_thread": False,  # Allow sharing between threads
+                    "isolation_level": None,  # Use autocommit mode
+                },
+                echo=False  # Set to True for SQL debugging
+            )
+        else:
+            # For file databases, use QueuePool for better concurrency
+            url = f"sqlite:///{self._db_file_name}"
+            engine = create_engine(
+                url,
+                poolclass=QueuePool,
+                pool_size=5,
+                max_overflow=10,
+                pool_pre_ping=True,
+                connect_args={
+                    "check_same_thread": False,
+                    "isolation_level": None,
+                },
+                echo=False
+            )
+        return engine
+
+    def create_connection(self) -> 'BaseConnection':
         """
-        Unregister a database connection object.
-        Does not close it.
+        Creates a new database connection object using SQLAlchemy Engine.
+        SQLAlchemy handles connection pooling and lifecycle management.
         """
-        connection_is_obj = isinstance(connection, BaseConnection)
-        if (not isinstance(connection, str)) and (not connection_is_obj):
-            raise_auto_arg_type_error("connection")
-        connection_uid_str = connection.uid if connection_is_obj else connection
-        self.check_connection_integrity(connection_uid_str)
-        # TODO: consider implementing garbage collector ref-count check
-        del self._connections[connection_uid_str]
+        return BaseConnection(self)
 
     def init_db(self, commit_fkey: bool = True) -> 'BaseConnection':
         if not isinstance(commit_fkey, bool):
             raise_auto_arg_type_error("commit_fkey")
-        connection = self._register_connection()
+        connection = self.create_connection()
         connection.open()
         connection.enforce_foreign_key_constraints(commit_fkey)
         return connection
@@ -149,51 +114,29 @@ def init_db_auto_close(self):
         # don't bother to commit fkey constraint because close() will commit before connection closure
         connection = self.init_db(False)
         connection.close()
-        self._unregister_connection(connection.uid)
+        # SQLAlchemy automatically handles connection cleanup
 
     def close_db(self):
-        uids = []
-        self.check_connection_registry_integrity()
-        for uid, connection in self._connections.items():
-            connection.close()
-            uids.append(uid)
-        for uid in uids:
-            self._unregister_connection(uid)
-        if not len(self._connections) == 0:
-            raise RuntimeError("Not all connections were closed properly.")
-        self._connections = dict()
-
-
-class MemDB(Database):
-    def __init__(self, options: dict[str, Any] = {}):
-        super().__init__(options=options)
-        if not self.in_memory:
-            raise ValueError("MemDB can only be used for in-memory databases.")
+        """Close the database and dispose of the SQLAlchemy engine to clean up connection pool."""
+        if hasattr(self, '_engine') and self._engine is not None:
+            self._engine.dispose()
 
-    def check_connection_registry_integrity(self, skip_registry_type_check: bool = False):
-        """
-        Check the integrity of the connection registry.
-        Implements early raise if more than one connection is found,
-        since in-memory databases can only have one connection.
-        """
-        if not isinstance(skip_registry_type_check, bool):
-            raise_auto_arg_type_error("skip_registry_type_check")
-        if not skip_registry_type_check:
-            if not isinstance(self._connections, dict):
-                raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}")
-        if (num_connections := len(self._connections)) > 1:
-            raise ValueError(f"In-memory databases can only have one active connection Found {num_connections}.")
-        return super().check_connection_registry_integrity(skip_registry_type_check=True) # Registry type already checked
 
-    def init_db_auto_close(self):
-        raise ValueError("In-memory databases cease to exist upon closure.")
+# class MemDB(Database):
+#     def __init__(self, options: dict[str, Any] = {}):
+#         super().__init__(options=options)
+#         if not self.in_memory:
+#             raise ValueError("MemDB can only be used for in-memory databases.")
 
+#     def init_db_auto_close(self):
+#         raise ValueError("In-memory databases cease to exist upon closure.")
 
-class FileDB(Database):
-    def __init__(self, options: dict[str, Any] = {}):
-        super().__init__(options=options)
-        if not self.is_file:
-            raise ValueError("FileDB can only be used for file-based databases.")
+
+# class FileDB(Database):
+#     def __init__(self, options: dict[str, Any] = {}):
+#         super().__init__(options=options)
+#         if not self.is_file:
+#             raise ValueError("FileDB can only be used for file-based databases.")
 
 
 class BaseConnection:
@@ -201,9 +144,7 @@ def __init__(self, db: Database):
         if not isinstance(db, Database):
             raise_auto_arg_type_error("db")
         self._db: Database = db
-        self._id = str(uuid.uuid4())
-        self._connection: sqlite3.Connection | None = None
-        self._cursor: sqlite3.Cursor | None = None
+        self._connection: SQLAConnection | None = None
 
     @property
     def db_name(self):
@@ -225,72 +166,64 @@ def db_is_file(self):
     def db_fkey_constraint(self):
         return self._db._fkey_constraint
 
-    @property
-    def uid(self):
-        # TODO: Hash together the uuid, db_name, and possibly also the location in memory to ensure uniqueness?
-        return self._id
-
-    @property
-    def registered(self):
-        self._db.check_connection_integrity(self) # raises error if not registered
-        return True
-
     @property
     def connection(self):
         if self._connection is None:
             raise RuntimeError("Connection is not open.")
-        if not isinstance(self._connection, sqlite3.Connection):
-            raise TypeError(f"Expected self._connection to be sqlite3.Connection, got {type(self._connection).__name__} instead.")
+        if not isinstance(self._connection, SQLAConnection):
+            raise TypeError(f"Expected self._connection to be SQLAlchemy Connection, got {type(self._connection).__name__} instead.")
         return self._connection
 
-    @property
-    def cursor(self):
-        if self._cursor is None:
-            raise RuntimeError("Cursor is not open.")
-        if not isinstance(self._cursor, sqlite3.Cursor):
-            raise TypeError(f"Expected self._cursor to be sqlite3.Cursor, got {type(self._cursor).__name__} instead.")
-        return self._cursor
-
     def enforce_foreign_key_constraints(self, commit: bool = True):
         if not isinstance(commit, bool):
             raise_auto_arg_type_error("commit")
         if self.db_fkey_constraint:
-            self.cursor.execute("PRAGMA foreign_keys = ON;")
+            self.connection.execute(text("PRAGMA foreign_keys = ON;"))
             if commit:
                 self.connection.commit()
 
     def _open(self):
-        self._connection = sqlite3.connect(self.db_file_name)
-        self._cursor = self._connection.cursor()
+        """Open a new SQLAlchemy connection from the engine."""
+        self._connection = self._db.engine.connect()
 
     def open(self):
-        if isinstance(self._connection, sqlite3.Connection):
+        if self._connection is not None:
            raise RuntimeError("Connection is already open.")
-        if not (self._connection is None):
-            raise TypeError(f"Expected self._connection to be None, got {type(self._connection).__name__} instead.")
         self._open()
 
     def _close(self):
-        self.cursor.close()
-        self._cursor = None
-        self.connection.close()
-        self._connection = None
+        """Close the SQLAlchemy connection."""
+        if self._connection is not None:
+            self._connection.close()
+            self._connection = None
 
     def close(self):
-        self.connection.commit()
+        if self._connection is not None:
+            self.connection.commit()
         self._close()
 
     def _execute(self, query: str, parameters: tuple = ()):
         """
-        Execute a SQL query with the given parameters.
+        Execute a SQL query with the given parameters using SQLAlchemy.
         Performs no checks/validation. Prefer `execute` unless you need raw access.
         """
-        result = self.cursor.execute(query, parameters)
+        # Convert parameters tuple to dict for SQLAlchemy if needed
+        if parameters:
+            # Create numbered parameter dict for SQLAlchemy
+            param_dict = {f"param_{i}": param for i, param in enumerate(parameters)}
+            # Replace ? placeholders with :param_N format
+            param_query = query
+            for i in range(len(parameters)):
+                param_query = param_query.replace("?", f":param_{i}", 1)
+            result = self.connection.execute(text(param_query), param_dict)
+        else:
+            result = self.connection.execute(text(query))
         return result
 
     def execute(self, query: str, parameters: tuple | None = None):
         """
-        Execute a SQL query with the given parameters.
+        Execute a SQL query with the given parameters using SQLAlchemy text() construct.
+        Returns the SQLAlchemy Result object for chaining fetchone/fetchall operations.
         """
         # Validate query and parameters
         if not isinstance(query, str):
@@ -305,125 +238,296 @@
         # If `params` points to an object that isn't a tuple or None (per previous condition), raise a TypeError
         elif not isinstance(params, tuple):
             raise_auto_arg_type_error("parameters")
-        # Execute query with `params`
-        result = self._execute(query, params)
-        return result
+        # Execute query with `params` and store the result for potential fetching
+        self._last_result = self._execute(query, params)
+        return self._last_result
 
     def commit(self):
         self.connection.commit()
 
     def fetchone(self):
-        return self.cursor.fetchone()
+        """Fetch one row from the last executed query result."""
+        if not hasattr(self, '_last_result') or self._last_result is None:
+            raise RuntimeError("No query has been executed. Call execute() first.")
+        return self._last_result.fetchone()
 
     def fetchall(self):
-        return self.cursor.fetchall()
+        """Fetch all rows from the last executed query result."""
+        if not hasattr(self, '_last_result') or self._last_result is None:
+            raise RuntimeError("No query has been executed. Call execute() first.")
+        return self._last_result.fetchall()
 
-    def _validate_sql_identifier(self, identifier: str, identifier_type: str = "identifier") -> str:
+    def _validate_identifier(self, identifier: str, identifier_type: str = "identifier") -> str:
         """
-        Validate and sanitize SQL identifiers (table names, column names) to prevent SQL injection.
-        Args:
-            identifier: The identifier to validate
-            identifier_type: Type of identifier for error messages (e.g., "table name", "column name")
-        Returns:
-            The validated identifier
-        Raises:
-            ValueError: If the identifier is invalid or potentially dangerous
+        Validate SQL identifiers with reserved word checking and basic format validation.
+        SQLAlchemy handles parameterization, but we still validate identifier safety.
         """
         if not isinstance(identifier, str):
             raise TypeError(f"SQL {identifier_type} must be a string, got {type(identifier).__name__}")
         if len(identifier) == 0:
             raise ValueError(f"SQL {identifier_type} cannot be empty")
-        # Check for valid identifier pattern: starts with letter/underscore, contains only alphanumeric/underscore
+
+        # Basic format validation
         if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', identifier):
             raise ValueError(f"Invalid SQL {identifier_type}: '{identifier}'. Must start with letter or underscore and contain only letters, numbers, and underscores.")
-        # Check against SQLite reserved words (common ones that could cause issues)
-        reserved_words = {
-            'abort', 'action', 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
-            'attach', 'autoincrement', 'before', 'begin', 'between', 'by', 'cascade', 'case',
-            'cast', 'check', 'collate', 'column', 'commit', 'conflict', 'constraint', 'create',
-            'cross', 'current', 'current_date', 'current_time', 'current_timestamp', 'database',
-            'default', 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', 'do',
-            'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', 'exists', 'explain',
-            'fail', 'filter', 'following', 'for', 'foreign', 'from', 'full', 'glob', 'group',
-            'having', 'if', 'ignore', 'immediate', 'in', 'index', 'indexed', 'initially', 'inner',
-            'insert', 'instead', 'intersect', 'into', 'is', 'isnull', 'join', 'key', 'left',
-            'like', 'limit', 'match', 'natural', 'no', 'not', 'notnull', 'null', 'of', 'offset',
-            'on', 'or', 'order', 'outer', 'over', 'partition', 'plan', 'pragma', 'preceding',
-            'primary', 'query', 'raise', 'range', 'recursive', 'references', 'regexp', 'reindex',
-            'release', 'rename', 'replace', 'restrict', 'right', 'rollback', 'row', 'rows',
-            'savepoint', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', 'transaction',
-            'trigger', 'unbounded', 'union', 'unique', 'update', 'using', 'vacuum', 'values',
-            'view', 'virtual', 'when', 'where', 'window', 'with', 'without'
+
+        # Check against SQL reserved words
+        sql_reserved_words = {
+            'ABORT', 'ACTION', 'ADD', 'AFTER', 'ALL', 'ALTER', 'ANALYZE', 'AND', 'AS', 'ASC',
+            'ATTACH', 'AUTOINCREMENT', 'BEFORE', 'BEGIN', 'BETWEEN', 'BY', 'CASCADE', 'CASE',
+            'CAST', 'CHECK', 'COLLATE', 'COLUMN', 'COMMIT', 'CONFLICT', 'CONSTRAINT', 'CREATE',
+            'CROSS', 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'DATABASE', 'DEFAULT',
+            'DEFERRABLE', 'DEFERRED', 'DELETE', 'DESC', 'DETACH', 'DISTINCT', 'DROP', 'EACH',
+            'ELSE', 'END', 'ESCAPE', 'EXCEPT', 'EXCLUSIVE', 'EXISTS', 'EXPLAIN', 'FAIL', 'FOR',
+            'FOREIGN', 'FROM', 'FULL', 'GLOB', 'GROUP', 'HAVING', 'IF', 'IGNORE', 'IMMEDIATE',
+            'IN', 'INDEX', 'INDEXED', 'INITIALLY', 'INNER', 'INSERT', 'INSTEAD', 'INTERSECT',
+            'INTO', 'IS', 'ISNULL', 'JOIN', 'KEY', 'LEFT', 'LIKE', 'LIMIT', 'MATCH', 'NATURAL',
+            'NO', 'NOT', 'NOTNULL', 'NULL', 'OF', 'OFFSET', 'ON', 'OR', 'ORDER', 'OUTER', 'PLAN',
+            'PRAGMA', 'PRIMARY', 'QUERY', 'RAISE', 'RECURSIVE', 'REFERENCES', 'REGEXP', 'REINDEX',
+            'RELEASE', 'RENAME', 'REPLACE', 'RESTRICT', 'RIGHT', 'ROLLBACK', 'ROW', 'SAVEPOINT',
+            'SELECT', 'SET', 'TABLE', 'TEMP', 'TEMPORARY', 'THEN', 'TO', 'TRANSACTION', 'TRIGGER',
+            'UNION', 'UNIQUE', 'UPDATE', 'USING', 'VACUUM', 'VALUES', 'VIEW', 'VIRTUAL', 'WHEN',
+            'WHERE', 'WITH', 'WITHOUT'
         }
-        if identifier.lower() in reserved_words:
-            raise ValueError(f"SQL {identifier_type} '{identifier}' is a reserved word and cannot be used")
+
+        if identifier.upper() in sql_reserved_words:
+            raise ValueError(f"SQL {identifier_type} '{identifier}' is a reserved word")
+
         return identifier
 
-    def _escape_sql_identifier(self, identifier: str) -> str:
+    def _validate_column_type(self, col_type: str) -> str:
+        """
+        Validate SQLite column type specification.
+        Returns the normalized (uppercase) column type.
+        """
+        if not isinstance(col_type, str):
+            raise TypeError(f"Column type must be a string, got {type(col_type).__name__}")
+
+        # Normalize to uppercase
+        normalized_type = col_type.upper().strip()
+
+        # Define allowed SQLite types
+        allowed_types = {
+            'TEXT', 'INTEGER', 'REAL', 'BLOB', 'NUMERIC',
+            'VARCHAR', 'CHAR', 'NVARCHAR', 'NCHAR',
+            'CLOB', 'DATE', 'DATETIME', 'TIMESTAMP',
+            'BOOLEAN', 'DECIMAL', 'DOUBLE', 'FLOAT',
+            'INT', 'BIGINT', 'SMALLINT', 'TINYINT'
+        }
+
+        # Extract base type (before any parentheses)
+        base_type_match = re.match(r'^([A-Z]+)', normalized_type)
+        if not base_type_match:
+            raise ValueError(f"Invalid column type format: '{col_type}'")
+
+        base_type = base_type_match.group(1)
+        if base_type not in allowed_types:
+            raise ValueError(f"Unsupported column type: '{col_type}'. Must be one of: {', '.join(sorted(allowed_types))}")
+
+        # Validate full type specification format (allowing precision/length specifiers)
+        if not re.match(r'^[A-Z]+(\([0-9,\s]+\))?$', normalized_type):
+            raise ValueError(f"Invalid column type format: '{col_type}'. Use format like 'VARCHAR(50)' or 'DECIMAL(10,2)'")
+
+        return normalized_type
+
+    def _build_column_definition(self, col_name: str, col_type: str) -> str:
         """
-        Escape SQL identifier by wrapping in double quotes and escaping any internal quotes.
-        This should only be used after validation.
+        Build a column definition string with proper validation and escaping.
+        Returns formatted column definition for CREATE TABLE statement.
         """
-        # Escape any double quotes in the identifier by doubling them
-        escaped = identifier.replace('"', '""')
-        return f'"{escaped}"'
+        validated_col_name = self._validate_identifier(col_name, "column name")
+        validated_col_type = self._validate_column_type(col_type)
+
+        # Use SQLite standard double-quote escaping for identifiers
+        escaped_col_name = f'"{validated_col_name}"'
+
+        return f"{escaped_col_name} {validated_col_type}"
 
     def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_table: bool = True, raise_if_exists: bool = True):
-        # Validate table_name argument
+        """
+        Create a new table using SQLAlchemy with proper validation and escaping.
+
+        Args:
+            table_name: Name of the table to create
+            columns: List of (column_name, column_type) tuples
+            temp_table: Whether to create a temporary table
+            raise_if_exists: Whether to raise an error if table already exists
+        """
+        # Validate arguments
         if not isinstance(table_name, str):
             raise_auto_arg_type_error("table_name")
         if len(table_name) == 0:
-            raise ValueError("'table_name' argument of create_table cannot be an empty string!")
-        # Validate temp_table argument
+            raise ValueError("'table_name' argument cannot be an empty string")
         if not isinstance(temp_table, bool):
             raise_auto_arg_type_error("temp_table")
         if not isinstance(raise_if_exists, bool):
             raise_auto_arg_type_error("raise_if_exists")
-        # Validate columns argument
-        if (not isinstance(columns, list)) or (not all(
+        if not isinstance(columns, list) or not all(
             isinstance(col, tuple) and len(col) == 2
-            and isinstance(col[0], str)
-            and isinstance(col[1], str)
-            for col in columns)):
+            and isinstance(col[0], str) and isinstance(col[1], str)
+            for col in columns
+        ):
             raise_auto_arg_type_error("columns")
-        # Validate and sanitize table name
-        validated_table_name = self._validate_sql_identifier(table_name, "table name")
-        escaped_table_name = self._escape_sql_identifier(validated_table_name)
-        # Check if table already exists using parameterized query
+
+        # Validate and normalize table name
+        validated_table_name = self._validate_identifier(table_name, "table name")
+        escaped_table_name = f'"{validated_table_name}"'
+
+        # Check if table already exists using SQLAlchemy parameterized query
         if raise_if_exists:
-            self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,))
-            if self.cursor.fetchone() is not None:
-                raise ValueError(f"Table '{validated_table_name}' already exists.")
-        # Validate and construct columns portion of query
-        validated_columns = []
+            check_query = text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name")
+            result = self.connection.execute(check_query, {"table_name": validated_table_name})
+            if result.fetchone() is not None:
+                raise ValueError(f"Table '{validated_table_name}' already exists")
+
+        # Build column definitions using helper method
+        column_definitions = []
         for col_name, col_type in columns:
-            # Validate column name
-            validated_col_name = self._validate_sql_identifier(col_name, "column name")
-            escaped_col_name = self._escape_sql_identifier(validated_col_name)
-            # Validate column type - allow only safe, known SQLite types
-            allowed_types = {
-                'TEXT', 'INTEGER', 'REAL', 'BLOB', 'NUMERIC',
-                'VARCHAR', 'CHAR', 'NVARCHAR', 'NCHAR',
-                'CLOB', 'DATE', 'DATETIME', 'TIMESTAMP',
-                'BOOLEAN', 'DECIMAL', 'DOUBLE', 'FLOAT',
-                'INT', 'BIGINT', 'SMALLINT', 'TINYINT'
-            }
-            # Allow type specifications with length/precision (e.g., VARCHAR(50), DECIMAL(10,2))
-            base_type = re.match(r'^([A-Z]+)', col_type.upper())
-            if not base_type or base_type.group(1) not in allowed_types:
-                raise ValueError(f"Unsupported column type: '{col_type}'. Must be one of: {', '.join(sorted(allowed_types))}")
-            # Basic validation for type specification format
-            if not re.match(r'^[A-Z]+(\([0-9,\s]+\))?$', col_type.upper()):
-                raise ValueError(f"Invalid column type format: '{col_type}'")
-            validated_columns.append(f"{escaped_col_name} {col_type.upper()}")
-        columns_qstr = ",\n    ".join(validated_columns)
-        # Assemble full query with escaped identifiers
+            column_def = self._build_column_definition(col_name, col_type)
+            column_definitions.append(column_def)
+
+        # Format column definitions for query
+        columns_clause = ",\n    ".join(column_definitions)
+
+        # Build CREATE TABLE statement
         temp_keyword = " TEMPORARY" if temp_table else ""
-        query = f"""CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} (
+        create_query = f"""CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} (
             id INTEGER PRIMARY KEY AUTOINCREMENT,
-            {columns_qstr}
-            );"""
-        self.execute(query)
+            {columns_clause}
+        )"""
+
+        # Execute using SQLAlchemy's text() construct for safe execution
+        self.connection.execute(text(create_query))
+
+    def _map_sqlite_type_to_sqlalchemy(self, sqlite_type: str):
+        """
+        Map SQLite type strings to SQLAlchemy column types.
+        Returns appropriate SQLAlchemy column type class.
+        """
+        # Normalize the type string
+        normalized_type = sqlite_type.upper().strip()
+        base_type_match = re.match(r'^([A-Z]+)', normalized_type)
+        if not base_type_match:
+            raise ValueError(f"Invalid SQLite type format: '{sqlite_type}'")
+        base_type = base_type_match.group(1)
+
+        # SQLite to SQLAlchemy type mapping
+        type_mapping = {
+            'TEXT': Text,
+            'VARCHAR': String,
+            'CHAR': String,
+            'NVARCHAR': String,
+            'NCHAR': String,
+            'CLOB': Text,
+            'INTEGER': Integer,
+            'INT': Integer,
+            'BIGINT': Integer,
+            'SMALLINT': Integer,
+            'TINYINT': Integer,
+            'REAL': Float,
+            'DOUBLE': Float,
+            'FLOAT': Float,
+            'NUMERIC': Float,
+            'DECIMAL': Float,
+            'BOOLEAN': Boolean,
+            'DATE': DateTime,
+            'DATETIME': DateTime,
+            'TIMESTAMP': DateTime,
+            'BLOB': LargeBinary,
+        }
+
+        if base_type not in type_mapping:
+            raise ValueError(f"Cannot map SQLite type '{sqlite_type}' to SQLAlchemy type")
+
+        sqlalchemy_type = type_mapping[base_type]
+
+        # Handle length specifications for string types
+        if base_type in ('VARCHAR', 'CHAR', 'NVARCHAR', 'NCHAR') and '(' in normalized_type:
+            length_match = re.search(r'\((\d+)\)', normalized_type)
+            if length_match:
+                length = int(length_match.group(1))
+                return sqlalchemy_type(length)
+
+        return sqlalchemy_type()
+
+    def create_table_with_sqlalchemy_ddl(self, table_name: str, columns: list[tuple[str, str]], temp_table: bool = True, raise_if_exists: bool = True):
+        """
+        Create a table using SQLAlchemy's Table and MetaData objects for enhanced DDL capabilities.
+        This provides better type safety and integration with SQLAlchemy's ORM features.
+
+        Args:
+            table_name: Name of the table to create
+            columns: List of (column_name, column_type) tuples
+            temp_table: Whether to create a temporary table
+            raise_if_exists: Whether to raise an error if table already exists
+        """
+        # Validate arguments (reuse validation from create_table)
+        if not isinstance(table_name, str) or len(table_name) == 0:
+            raise ValueError("table_name must be a non-empty string")
+        if not isinstance(temp_table, bool):
+            raise_auto_arg_type_error("temp_table")
+        if not isinstance(raise_if_exists, bool):
+            raise_auto_arg_type_error("raise_if_exists")
+        if not isinstance(columns, list) or not all(
+            isinstance(col, tuple) and len(col) == 2
+            and isinstance(col[0], str) and isinstance(col[1], str)
+            for col in columns
+        ):
+            raise_auto_arg_type_error("columns")
+
+        # Validate table name
+        validated_table_name = self._validate_identifier(table_name, "table name")
+
+        # Check if table exists
+        if raise_if_exists:
+            check_query = text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name")
+            result = self.connection.execute(check_query, {"table_name": validated_table_name})
+            if result.fetchone() is not None:
+                raise ValueError(f"Table '{validated_table_name}' already exists")
+
+        # Create MetaData object
+        metadata = MetaData()
+
+        # Build SQLAlchemy Column objects
+        sqlalchemy_columns = [
+            Column('id', Integer, primary_key=True, autoincrement=True)
+        ]
+
+        for col_name, col_type in columns:
+            validated_col_name = self._validate_identifier(col_name, "column name")
+            validated_col_type = self._validate_column_type(col_type)
+
+            # Map to SQLAlchemy type
+            sqlalchemy_type = self._map_sqlite_type_to_sqlalchemy(validated_col_type)
+            sqlalchemy_columns.append(Column(validated_col_name, sqlalchemy_type))
+
+        # Create Table object
+        if temp_table:
+            # For temporary tables, fall back to raw SQL since SQLAlchemy doesn't have direct support
+            temp_keyword = " TEMPORARY"
+            column_defs = []
+            for col in sqlalchemy_columns:
+                if col.name == 'id':
+                    column_defs.append(f'"{col.name}" {col.type.compile(self._db.engine.dialect)} PRIMARY KEY AUTOINCREMENT')
+                else:
+                    column_defs.append(f'"{col.name}" {col.type.compile(self._db.engine.dialect)}')
+
+            columns_clause = ",\n    ".join(column_defs)
+            escaped_table_name = f'"{validated_table_name}"'
+            create_query = f"""CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} (
+                {columns_clause}
+            )"""
+            self.connection.execute(text(create_query))
+        else:
+            # For regular tables, use SQLAlchemy's DDL capabilities
+            table = Table(
+                validated_table_name,
+                metadata,
+                *sqlalchemy_columns
+            )
+
+            # Create the table using SQLAlchemy DDL
+            metadata.create_all(self._db.engine, tables=[table], checkfirst=not raise_if_exists)
 
     def drop_table(self, table_name: str, raise_if_not_exists: bool = False):
         # Validate table_name argument
@@ -433,16 +537,23 @@
         raise ValueError("'table_name' argument of drop_table cannot be an empty string!")
         if not isinstance(raise_if_not_exists, bool):
             raise_auto_arg_type_error("raise_if_not_exists")
-        # Validate and sanitize table name
-        validated_table_name = self._validate_sql_identifier(table_name, "table name")
-        escaped_table_name = self._escape_sql_identifier(validated_table_name)
-        # Check if table exists using parameterized query
+
+        # Validate table name using basic validation
+        validated_table_name = self._validate_identifier(table_name, "table name")
+
+        # Check if table exists using SQLAlchemy parameterized query
         if raise_if_not_exists:
-            self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,))
-            if self.cursor.fetchone() is None:
+            result = self.connection.execute(
+                text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"),
+                {"table_name": validated_table_name}
+            )
+            if result.fetchone() is None:
                 raise ValueError(f"Table '{validated_table_name}' does not exist.")
-        # Execute DROP statement with escaped identifier
-        self.cursor.execute(f"DROP TABLE IF EXISTS {escaped_table_name};")
+
+        # Execute DROP TABLE query with proper identifier escaping
+        escaped_table_name = f'"{validated_table_name}"'
+        query = f"DROP TABLE IF EXISTS {escaped_table_name}"
+        self.connection.execute(text(query))
 
     # def read(self):
     #     pass
diff --git a/stat_log_db/tests/test_sql_injection.py b/stat_log_db/tests/test_sql_injection.py
index 447ba96..bf9bd37 100644
--- a/stat_log_db/tests/test_sql_injection.py
+++ b/stat_log_db/tests/test_sql_injection.py
@@ -6,7 +6,7 @@
 import sys
 from pathlib import Path
 
-from stat_log_db.db import MemDB
+from stat_log_db.db import Database
 
 
 # Add the src directory to the path to import the module
@@ -17,7 +17,7 @@
 @pytest.fixture
 def mem_db():
     """Create a test in-memory database and clean up after tests."""
-    sl_db = MemDB({
+    sl_db = Database({
         "is_mem": True,
         "fkey_constraint": True
     })
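Usage note (reviewer addition, not part of the patch): the sketch below is a minimal, assumed walk-through of the refactored engine-backed API, using the in-memory options already shown in cli.py and the test fixture. The table name, columns, and queries are illustrative placeholders, not code from the repository.

from stat_log_db.db import Database

# Hypothetical end-to-end flow against the new SQLAlchemy-backed classes.
db = Database({"is_mem": True, "fkey_constraint": True})   # StaticPool keeps the in-memory DB alive
conn = db.init_db()                                        # opens a BaseConnection, enables foreign keys

# create_table() validates identifiers and column types before emitting DDL;
# "stats", "label" and "value" are made-up names for this example.
conn.create_table("stats", [("label", "VARCHAR(50)"), ("value", "REAL")], temp_table=False)

# execute() still accepts "?" placeholders; _execute() rewrites them to
# named parameters (:param_0, :param_1, ...) for SQLAlchemy's text() construct.
conn.execute("INSERT INTO stats (label, value) VALUES (?, ?)", ("cpu", 0.42))
conn.commit()

conn.execute("SELECT label, value FROM stats")
rows = conn.fetchall()   # reads from the stored result of the last execute()

conn.close()     # commits, then releases the pooled connection
db.close_db()    # disposes the engine and its connection pool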