From cf65e95f001ae4bc2684dc2addc2fda41f862064 Mon Sep 17 00:00:00 2001 From: bjthres1 Date: Sat, 6 Sep 2025 17:46:44 -0400 Subject: [PATCH 1/3] feat(db): [WIP] integrate SQLAlchemy for db mgmt & connection pooling This update introduces SQLAlchemy for improved database handling, including connection pooling and enhanced query execution. The Database class now utilizes SQLAlchemy's Engine for managing connections, allowing for better concurrency and resource management. BREAKING CHANGE: The connection handling has been refactored to use SQLAlchemy, which may affect existing database interaction methods. --- requirements.txt | 1 + stat_log_db/pyproject.toml | 1 + stat_log_db/src/stat_log_db/db.py | 125 +++++++++++++++++++++--------- 3 files changed, 91 insertions(+), 36 deletions(-) diff --git a/requirements.txt b/requirements.txt index 228b8d4..efbfed8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +sqlalchemy>=2.0.0 pytest==8.4.1 pytest-cov==6.2.1 pdoc==15.0.4 diff --git a/stat_log_db/pyproject.toml b/stat_log_db/pyproject.toml index efd8321..f23b15c 100644 --- a/stat_log_db/pyproject.toml +++ b/stat_log_db/pyproject.toml @@ -9,6 +9,7 @@ description = "" readme = "README.md" requires-python = ">=3.12.10" dependencies = [ + "sqlalchemy>=2.0.0" ] [project.optional-dependencies] diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py index 6b29ea9..3b18395 100644 --- a/stat_log_db/src/stat_log_db/db.py +++ b/stat_log_db/src/stat_log_db/db.py @@ -4,6 +4,9 @@ import uuid from typing import Any +from sqlalchemy import create_engine, text, Engine, Connection as SQLAConnection +from sqlalchemy.engine import make_url +from sqlalchemy.pool import StaticPool, QueuePool from .exceptions import raise_auto_arg_type_error @@ -28,9 +31,46 @@ def __init__(self, options: dict[str, Any] = {}): self._db_name: str = options.get("db_name", str(uuid.uuid4())) self._db_file_name: str = ":memory:" if self._in_memory else 
self._db_name.replace(" ", "_") self._fkey_constraint: bool = options.get("fkey_constraint", True) - # Keep track of active connections (to ensure that they are closed) + + # Create SQLAlchemy Engine with appropriate connection pooling + self._engine = self._create_engine() + + # Keep track of active connections for compatibility (but SQLAlchemy handles the real pooling) self._connections: dict[str, BaseConnection] = dict() + def _create_engine(self) -> Engine: + """Create SQLAlchemy Engine with appropriate configuration for SQLite.""" + if self._in_memory: + # For in-memory databases, use StaticPool to ensure single connection + # and prevent the database from being destroyed when connections close + url = "sqlite:///:memory:" + engine = create_engine( + url, + poolclass=StaticPool, + pool_pre_ping=True, + connect_args={ + "check_same_thread": False, # Allow sharing between threads + "isolation_level": None, # Use autocommit mode + }, + echo=False # Set to True for SQL debugging + ) + else: + # For file databases, use QueuePool for better concurrency + url = f"sqlite:///{self._db_file_name}" + engine = create_engine( + url, + poolclass=QueuePool, + pool_size=5, + max_overflow=10, + pool_pre_ping=True, + connect_args={ + "check_same_thread": False, + "isolation_level": None, + }, + echo=False + ) + return engine + @property def name(self) -> str: return self._db_name @@ -50,6 +90,11 @@ def is_file(self) -> bool: @property def fkey_constraint(self) -> bool: return self._fkey_constraint + + @property + def engine(self) -> Engine: + """Get the SQLAlchemy Engine for this database.""" + return self._engine def check_connection_integrity(self, connection: 'str | BaseConnection', skip_registry_type_check: bool = False): """ @@ -115,7 +160,7 @@ def check_connection_registry_integrity(self, skip_registry_type_check: bool = F def _register_connection(self): """ Creates a new database connection object and registers it. - Does not open the connection. 
+ The connection uses SQLAlchemy Engine for connection management. """ connection = BaseConnection(self) self._connections[connection.uid] = connection @@ -162,6 +207,10 @@ def close_db(self): if not len(self._connections) == 0: raise RuntimeError("Not all connections were closed properly.") self._connections = dict() + + # Dispose of the SQLAlchemy engine to clean up connection pool + if hasattr(self, '_engine') and self._engine is not None: + self._engine.dispose() class MemDB(Database): @@ -202,8 +251,7 @@ def __init__(self, db: Database): raise_auto_arg_type_error("db") self._db: Database = db self._id = str(uuid.uuid4()) - self._connection: sqlite3.Connection | None = None - self._cursor: sqlite3.Cursor | None = None + self._connection: SQLAConnection | None = None @property def db_name(self): @@ -239,58 +287,59 @@ def registered(self): def connection(self): if self._connection is None: raise RuntimeError("Connection is not open.") - if not isinstance(self._connection, sqlite3.Connection): - raise TypeError(f"Expected self._connection to be sqlite3.Connection, got {type(self._connection).__name__} instead.") + if not isinstance(self._connection, SQLAConnection): + raise TypeError(f"Expected self._connection to be SQLAlchemy Connection, got {type(self._connection).__name__} instead.") return self._connection - @property - def cursor(self): - if self._cursor is None: - raise RuntimeError("Cursor is not open.") - if not isinstance(self._cursor, sqlite3.Cursor): - raise TypeError(f"Expected self._cursor to be sqlite3.Cursor, got {type(self._cursor).__name__} instead.") - return self._cursor - def enforce_foreign_key_constraints(self, commit: bool = True): if not isinstance(commit, bool): raise_auto_arg_type_error("commit") if self.db_fkey_constraint: - self.cursor.execute("PRAGMA foreign_keys = ON;") + self.connection.execute(text("PRAGMA foreign_keys = ON;")) if commit: self.connection.commit() def _open(self): - self._connection = 
sqlite3.connect(self.db_file_name) - self._cursor = self._connection.cursor() + """Open a new SQLAlchemy connection from the engine.""" + self._connection = self._db.engine.connect() def open(self): - if isinstance(self._connection, sqlite3.Connection): + if self._connection is not None: raise RuntimeError("Connection is already open.") - if not (self._connection is None): - raise TypeError(f"Expected self._connection to be None, got {type(self._connection).__name__} instead.") self._open() def _close(self): - self.cursor.close() - self._cursor = None - self.connection.close() - self._connection = None + """Close the SQLAlchemy connection.""" + if self._connection is not None: + self._connection.close() + self._connection = None def close(self): - self.connection.commit() + if self._connection is not None: + self.connection.commit() self._close() def _execute(self, query: str, parameters: tuple = ()): """ - Execute a SQL query with the given parameters. + Execute a SQL query with the given parameters using SQLAlchemy. Performs no checks/validation. Prefer `execute` unless you need raw access. """ - result = self.cursor.execute(query, parameters) + # Convert parameters tuple to dict for SQLAlchemy if needed + if parameters: + # Create numbered parameter dict for SQLAlchemy + param_dict = {f"param_{i}": param for i, param in enumerate(parameters)} + # Replace ? placeholders with :param_N format + param_query = query + for i in range(len(parameters)): + param_query = param_query.replace("?", f":param_{i}", 1) + result = self.connection.execute(text(param_query), param_dict) + else: + result = self.connection.execute(text(query)) return result def execute(self, query: str, parameters: tuple | None = None): """ - Execute a SQL query with the given parameters. + Execute a SQL query with the given parameters using SQLAlchemy text() construct. 
""" # Validate query and parameters if not isinstance(query, str): @@ -313,10 +362,14 @@ def commit(self): self.connection.commit() def fetchone(self): - return self.cursor.fetchone() + # For SQLAlchemy, we need to get the last result from the connection + # This is a simplified approach - in practice you'd want to store the result object + raise NotImplementedError("fetchone() needs to be called on the result object returned by execute()") def fetchall(self): - return self.cursor.fetchall() + # For SQLAlchemy, we need to get the last result from the connection + # This is a simplified approach - in practice you'd want to store the result object + raise NotImplementedError("fetchall() needs to be called on the result object returned by execute()") def _validate_sql_identifier(self, identifier: str, identifier_type: str = "identifier") -> str: """ @@ -391,8 +444,8 @@ def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab escaped_table_name = self._escape_sql_identifier(validated_table_name) # Check if table already exists using parameterized query if raise_if_exists: - self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,)) - if self.cursor.fetchone() is not None: + result = self.connection.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"), {"table_name": validated_table_name}) + if result.fetchone() is not None: raise ValueError(f"Table '{validated_table_name}' already exists.") # Validate and construct columns portion of query validated_columns = [] @@ -423,7 +476,7 @@ def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab id INTEGER PRIMARY KEY AUTOINCREMENT, {columns_qstr} );""" - self.execute(query) + self.connection.execute(text(query)) def drop_table(self, table_name: str, raise_if_not_exists: bool = False): # Validate table_name argument @@ -438,11 +491,11 @@ def drop_table(self, table_name: str, raise_if_not_exists: 
bool = False): escaped_table_name = self._escape_sql_identifier(validated_table_name) # Check if table exists using parameterized query if raise_if_not_exists: - self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,)) - if self.cursor.fetchone() is None: + result = self.connection.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"), {"table_name": validated_table_name}) + if result.fetchone() is None: raise ValueError(f"Table '{validated_table_name}' does not exist.") # Execute DROP statement with escaped identifier - self.cursor.execute(f"DROP TABLE IF EXISTS {escaped_table_name};") + self.connection.execute(text(f"DROP TABLE IF EXISTS {escaped_table_name};")) # def read(self): # pass From e6c548286d6c27bc08d8a1d88d910a0508e89feb Mon Sep 17 00:00:00 2001 From: bjthres1 Date: Tue, 9 Sep 2025 17:47:39 -0400 Subject: [PATCH 2/3] feat(db): refactor database class using SQLAlchemy - Improved database connection management using SQLAlchemy's session system. - Improved query sanitization/safety using SQLAlchemy. - Updated the CLI and test files to utilize the Database class instead of MemDB. 
--- stat_log_db/src/stat_log_db/cli.py | 4 +- stat_log_db/src/stat_log_db/db.py | 333 ++++++++---------------- stat_log_db/tests/test_sql_injection.py | 4 +- 3 files changed, 106 insertions(+), 235 deletions(-) diff --git a/stat_log_db/src/stat_log_db/cli.py b/stat_log_db/src/stat_log_db/cli.py index 55189b7..d9a872b 100644 --- a/stat_log_db/src/stat_log_db/cli.py +++ b/stat_log_db/src/stat_log_db/cli.py @@ -2,7 +2,7 @@ # import sys # from .parser import create_parser -from .db import MemDB # , FileDB, Database, BaseConnection +from .db import Database # , MemDB, FileDB, BaseConnection def main(): @@ -18,7 +18,7 @@ def main(): # print(f"{args=}") - sl_db = MemDB({ + sl_db = Database({ "is_mem": True, "fkey_constraint": True }) diff --git a/stat_log_db/src/stat_log_db/db.py b/stat_log_db/src/stat_log_db/db.py index 3b18395..d14e8c1 100644 --- a/stat_log_db/src/stat_log_db/db.py +++ b/stat_log_db/src/stat_log_db/db.py @@ -31,12 +31,33 @@ def __init__(self, options: dict[str, Any] = {}): self._db_name: str = options.get("db_name", str(uuid.uuid4())) self._db_file_name: str = ":memory:" if self._in_memory else self._db_name.replace(" ", "_") self._fkey_constraint: bool = options.get("fkey_constraint", True) - # Create SQLAlchemy Engine with appropriate connection pooling self._engine = self._create_engine() - - # Keep track of active connections for compatibility (but SQLAlchemy handles the real pooling) - self._connections: dict[str, BaseConnection] = dict() + + @property + def name(self) -> str: + return self._db_name + + @property + def file_name(self) -> str: + return self._db_file_name + + @property + def in_memory(self) -> bool: + return self._in_memory + + @property + def is_file(self) -> bool: + return self._is_file + + @property + def fkey_constraint(self) -> bool: + return self._fkey_constraint + + @property + def engine(self) -> Engine: + """Get the SQLAlchemy Engine for this database.""" + return self._engine def _create_engine(self) -> Engine: """Create 
SQLAlchemy Engine with appropriate configuration for SQLite.""" @@ -71,119 +92,17 @@ def _create_engine(self) -> Engine: ) return engine - @property - def name(self) -> str: - return self._db_name - - @property - def file_name(self) -> str: - return self._db_file_name - - @property - def in_memory(self) -> bool: - return self._in_memory - - @property - def is_file(self) -> bool: - return self._is_file - - @property - def fkey_constraint(self) -> bool: - return self._fkey_constraint - - @property - def engine(self) -> Engine: - """Get the SQLAlchemy Engine for this database.""" - return self._engine - - def check_connection_integrity(self, connection: 'str | BaseConnection', skip_registry_type_check: bool = False): - """ - Check the integrity of a given connection's registration. - The connection to be checked can be passed as an UID string or a connection object (instance of BaseConnection). - """ - if not isinstance(skip_registry_type_check, bool): - raise_auto_arg_type_error("skip_registry_type_check") - connection_is_uid_str = isinstance(connection, str) - connection_is_obj = isinstance(connection, BaseConnection) - if (not connection_is_uid_str) and (not connection_is_obj): - raise_auto_arg_type_error("connection") - if self._connections is None or len(self._connections) == 0: - raise ValueError(f"Connection {connection.uid if connection_is_obj else connection} is not registered, as Connection Registry contains no connections.") - # Check that the registry is of the expected type - if (not skip_registry_type_check) and (not isinstance(self._connections, dict)): - raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}") - # If the passed-in connection is a uid string, - # search the registry keys for that uid string. - # Check that a matching connection is found, - # that it has a valid UID, and that it is registered - # under the uid that it has (registry key = found connection's uid). 
- if connection_is_uid_str: - if len(connection) == 0: - raise ValueError("Connection UID string is empty.") - found_connection = self._connections.get(connection, None) - if found_connection is None: - raise ValueError(f"Connection '{connection}' is not registered.") - if not isinstance(found_connection.uid, str): - raise TypeError(f"Expected the found connection's uid to be str, got {type(found_connection.uid).__name__} instead.") - if len(found_connection.uid) == 0: - raise ValueError("Found connection's uid string is empty.") - if found_connection.uid != connection: - raise ValueError(f"Connection '{connection}' is registered under non-matching uid: {found_connection.uid}") - # If the passed-in connection is a BaseConnection object, - # check that it has a valid uid and that it's UID is in the registry - elif connection_is_obj: - if not isinstance(connection.uid, str): - raise TypeError(f"Expected the connection's uid to be str, got {type(connection.uid).__name__} instead.") - if connection.uid not in self._connections: - raise ValueError(f"Connection '{connection.uid}' is not registered, or is registered under the wrong uid.") - - def check_connection_registry_integrity(self, skip_registry_type_check: bool = False): - """ - Check the integrity of the connection registry. - If not all connections are registered, no error is raised. 
- """ - if not isinstance(skip_registry_type_check, bool): - raise_auto_arg_type_error("skip_registry_type_check") - # Check that the registry is of the expected type - if (not skip_registry_type_check) and (not isinstance(self._connections, dict)): - raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}") - # If there are no connections, nothing to check - if len(self._connections) == 0: - return - # Check that all registered connections are registered under a UID of the correct type and are instances of BaseConnection - if any((not isinstance(uid, str)) or (not isinstance(conn, BaseConnection)) for uid, conn in self._connections.items()): - raise TypeError("All connections must be registered by their UID string and be instances of BaseConnection.") - # Perform individual connection integrity checks - for uid in self._connections.keys(): - self.check_connection_integrity(uid, skip_registry_type_check=True) # Registry type already checked - - def _register_connection(self): + def create_connection(self) -> 'BaseConnection': """ - Creates a new database connection object and registers it. - The connection uses SQLAlchemy Engine for connection management. + Creates a new database connection object using SQLAlchemy Engine. + SQLAlchemy handles connection pooling and lifecycle management. """ - connection = BaseConnection(self) - self._connections[connection.uid] = connection - self.check_connection_integrity(connection) - return connection - - def _unregister_connection(self, connection: 'str | BaseConnection'): - """ - Unregister a database connection object. - Does not close it. 
- """ - connection_is_obj = isinstance(connection, BaseConnection) - if (not isinstance(connection, str)) and (not connection_is_obj): - raise_auto_arg_type_error("connection") - connection_uid_str = connection.uid if connection_is_obj else connection - self.check_connection_integrity(connection_uid_str) - # TODO: consider implementing garbage collector ref-count check - del self._connections[connection_uid_str] + return BaseConnection(self) def init_db(self, commit_fkey: bool = True) -> 'BaseConnection': if not isinstance(commit_fkey, bool): raise_auto_arg_type_error("commit_fkey") - connection = self._register_connection() + connection = self.create_connection() connection.open() connection.enforce_foreign_key_constraints(commit_fkey) return connection @@ -194,55 +113,29 @@ def init_db_auto_close(self): # don't bother to commit fkey constraint because close() will commit before connection closure connection = self.init_db(False) connection.close() - self._unregister_connection(connection.uid) + # SQLAlchemy automatically handles connection cleanup def close_db(self): - uids = [] - self.check_connection_registry_integrity() - for uid, connection in self._connections.items(): - connection.close() - uids.append(uid) - for uid in uids: - self._unregister_connection(uid) - if not len(self._connections) == 0: - raise RuntimeError("Not all connections were closed properly.") - self._connections = dict() - - # Dispose of the SQLAlchemy engine to clean up connection pool + """Close the database and dispose of the SQLAlchemy engine to clean up connection pool.""" if hasattr(self, '_engine') and self._engine is not None: self._engine.dispose() -class MemDB(Database): - def __init__(self, options: dict[str, Any] = {}): - super().__init__(options=options) - if not self.in_memory: - raise ValueError("MemDB can only be used for in-memory databases.") +# class MemDB(Database): +# def __init__(self, options: dict[str, Any] = {}): +# super().__init__(options=options) +# if not 
self.in_memory: +# raise ValueError("MemDB can only be used for in-memory databases.") - def check_connection_registry_integrity(self, skip_registry_type_check: bool = False): - """ - Check the integrity of the connection registry. - Implements early raise if more than one connection is found, - since in-memory databases can only have one connection. - """ - if not isinstance(skip_registry_type_check, bool): - raise_auto_arg_type_error("skip_registry_type_check") - if not skip_registry_type_check: - if not isinstance(self._connections, dict): - raise TypeError(f"Expected connection registry to be a dictionary but it was {type(self._connections).__name__}") - if (num_connections := len(self._connections)) > 1: - raise ValueError(f"In-memory databases can only have one active connection Found {num_connections}.") - return super().check_connection_registry_integrity(skip_registry_type_check=True) # Registry type already checked +# def init_db_auto_close(self): +# raise ValueError("In-memory databases cease to exist upon closure.") - def init_db_auto_close(self): - raise ValueError("In-memory databases cease to exist upon closure.") - -class FileDB(Database): - def __init__(self, options: dict[str, Any] = {}): - super().__init__(options=options) - if not self.is_file: - raise ValueError("FileDB can only be used for file-based databases.") +# class FileDB(Database): +# def __init__(self, options: dict[str, Any] = {}): +# super().__init__(options=options) +# if not self.is_file: +# raise ValueError("FileDB can only be used for file-based databases.") class BaseConnection: @@ -250,7 +143,6 @@ def __init__(self, db: Database): if not isinstance(db, Database): raise_auto_arg_type_error("db") self._db: Database = db - self._id = str(uuid.uuid4()) self._connection: SQLAConnection | None = None @property @@ -273,16 +165,6 @@ def db_is_file(self): def db_fkey_constraint(self): return self._db._fkey_constraint - @property - def uid(self): - # TODO: Hash together the uuid, 
db_name, and possibly also the location in memory to ensure uniqueness? - return self._id - - @property - def registered(self): - self._db.check_connection_integrity(self) # raises error if not registered - return True - @property def connection(self): if self._connection is None: @@ -340,6 +222,7 @@ def _execute(self, query: str, parameters: tuple = ()): def execute(self, query: str, parameters: tuple | None = None): """ Execute a SQL query with the given parameters using SQLAlchemy text() construct. + Returns the SQLAlchemy Result object for chaining fetchone/fetchall operations. """ # Validate query and parameters if not isinstance(query, str): @@ -354,73 +237,38 @@ def execute(self, query: str, parameters: tuple | None = None): # If `params` points to an object that isn't a tuple or None (per previous condition), raise a TypeError elif not isinstance(params, tuple): raise_auto_arg_type_error("parameters") - # Execute query with `params` - result = self._execute(query, params) - return result + # Execute query with `params` and store the result for potential fetching + self._last_result = self._execute(query, params) + return self._last_result def commit(self): self.connection.commit() def fetchone(self): - # For SQLAlchemy, we need to get the last result from the connection - # This is a simplified approach - in practice you'd want to store the result object - raise NotImplementedError("fetchone() needs to be called on the result object returned by execute()") + """Fetch one row from the last executed query result.""" + if not hasattr(self, '_last_result') or self._last_result is None: + raise RuntimeError("No query has been executed. 
Call execute() first.") + return self._last_result.fetchone() def fetchall(self): - # For SQLAlchemy, we need to get the last result from the connection - # This is a simplified approach - in practice you'd want to store the result object - raise NotImplementedError("fetchall() needs to be called on the result object returned by execute()") + """Fetch all rows from the last executed query result.""" + if not hasattr(self, '_last_result') or self._last_result is None: + raise RuntimeError("No query has been executed. Call execute() first.") + return self._last_result.fetchall() - def _validate_sql_identifier(self, identifier: str, identifier_type: str = "identifier") -> str: + def _validate_identifier(self, identifier: str, identifier_type: str = "identifier") -> str: """ - Validate and sanitize SQL identifiers (table names, column names) to prevent SQL injection. - Args: - identifier: The identifier to validate - identifier_type: Type of identifier for error messages (e.g., "table name", "column name") - Returns: - The validated identifier - Raises: - ValueError: If the identifier is invalid or potentially dangerous + Basic validation for identifiers. SQLAlchemy handles SQL injection protection. """ if not isinstance(identifier, str): raise TypeError(f"SQL {identifier_type} must be a string, got {type(identifier).__name__}") if len(identifier) == 0: raise ValueError(f"SQL {identifier_type} cannot be empty") - # Check for valid identifier pattern: starts with letter/underscore, contains only alphanumeric/underscore + # Basic validation - SQLAlchemy will handle the rest if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', identifier): raise ValueError(f"Invalid SQL {identifier_type}: '{identifier}'. 
Must start with letter or underscore and contain only letters, numbers, and underscores.") - # Check against SQLite reserved words (common ones that could cause issues) - reserved_words = { - 'abort', 'action', 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', 'cascade', 'case', - 'cast', 'check', 'collate', 'column', 'commit', 'conflict', 'constraint', 'create', - 'cross', 'current', 'current_date', 'current_time', 'current_timestamp', 'database', - 'default', 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', 'do', - 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', 'exists', 'explain', - 'fail', 'filter', 'following', 'for', 'foreign', 'from', 'full', 'glob', 'group', - 'having', 'if', 'ignore', 'immediate', 'in', 'index', 'indexed', 'initially', 'inner', - 'insert', 'instead', 'intersect', 'into', 'is', 'isnull', 'join', 'key', 'left', - 'like', 'limit', 'match', 'natural', 'no', 'not', 'notnull', 'null', 'of', 'offset', - 'on', 'or', 'order', 'outer', 'over', 'partition', 'plan', 'pragma', 'preceding', - 'primary', 'query', 'raise', 'range', 'recursive', 'references', 'regexp', 'reindex', - 'release', 'rename', 'replace', 'restrict', 'right', 'rollback', 'row', 'rows', - 'savepoint', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', 'transaction', - 'trigger', 'unbounded', 'union', 'unique', 'update', 'using', 'vacuum', 'values', - 'view', 'virtual', 'when', 'where', 'window', 'with', 'without' - } - if identifier.lower() in reserved_words: - raise ValueError(f"SQL {identifier_type} '{identifier}' is a reserved word and cannot be used") return identifier - def _escape_sql_identifier(self, identifier: str) -> str: - """ - Escape SQL identifier by wrapping in double quotes and escaping any internal quotes. - This should only be used after validation. 
- """ - # Escape any double quotes in the identifier by doubling them - escaped = identifier.replace('"', '""') - return f'"{escaped}"' - def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_table: bool = True, raise_if_exists: bool = True): # Validate table_name argument if not isinstance(table_name, str): @@ -439,20 +287,25 @@ def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab and isinstance(col[1], str) for col in columns)): raise_auto_arg_type_error("columns") - # Validate and sanitize table name - validated_table_name = self._validate_sql_identifier(table_name, "table name") - escaped_table_name = self._escape_sql_identifier(validated_table_name) - # Check if table already exists using parameterized query + + # Validate table name using basic validation + validated_table_name = self._validate_identifier(table_name, "table name") + + # Check if table already exists using SQLAlchemy parameterized query if raise_if_exists: - result = self.connection.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"), {"table_name": validated_table_name}) + result = self.connection.execute( + text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"), + {"table_name": validated_table_name} + ) if result.fetchone() is not None: raise ValueError(f"Table '{validated_table_name}' already exists.") - # Validate and construct columns portion of query - validated_columns = [] + + # Validate columns and build column definitions + column_definitions = [] for col_name, col_type in columns: # Validate column name - validated_col_name = self._validate_sql_identifier(col_name, "column name") - escaped_col_name = self._escape_sql_identifier(validated_col_name) + validated_col_name = self._validate_identifier(col_name, "column name") + # Validate column type - allow only safe, known SQLite types allowed_types = { 'TEXT', 'INTEGER', 'REAL', 'BLOB', 'NUMERIC', @@ -461,21 +314,32 @@ def 
create_table(self, table_name: str, columns: list[tuple[str, str]], temp_tab 'BOOLEAN', 'DECIMAL', 'DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'SMALLINT', 'TINYINT' } + # Allow type specifications with length/precision (e.g., VARCHAR(50), DECIMAL(10,2)) base_type = re.match(r'^([A-Z]+)', col_type.upper()) if not base_type or base_type.group(1) not in allowed_types: raise ValueError(f"Unsupported column type: '{col_type}'. Must be one of: {', '.join(sorted(allowed_types))}") + # Basic validation for type specification format if not re.match(r'^[A-Z]+(\([0-9,\s]+\))?$', col_type.upper()): raise ValueError(f"Invalid column type format: '{col_type}'") - validated_columns.append(f"{escaped_col_name} {col_type.upper()}") - columns_qstr = ",\n ".join(validated_columns) - # Assemble full query with escaped identifiers + + # Use double quotes for identifier escaping (SQLite standard) + escaped_col_name = f'"{validated_col_name}"' + column_definitions.append(f"{escaped_col_name} {col_type.upper()}") + + columns_qstr = ",\n ".join(column_definitions) + + # Build CREATE TABLE statement with proper identifier escaping temp_keyword = " TEMPORARY" if temp_table else "" + escaped_table_name = f'"{validated_table_name}"' + query = f"""CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} ( id INTEGER PRIMARY KEY AUTOINCREMENT, {columns_qstr} - );""" + )""" + + # Execute using SQLAlchemy's text() construct self.connection.execute(text(query)) def drop_table(self, table_name: str, raise_if_not_exists: bool = False): @@ -486,16 +350,23 @@ def drop_table(self, table_name: str, raise_if_not_exists: bool = False): raise ValueError("'table_name' argument of drop_table cannot be an empty string!") if not isinstance(raise_if_not_exists, bool): raise_auto_arg_type_error("raise_if_not_exists") - # Validate and sanitize table name - validated_table_name = self._validate_sql_identifier(table_name, "table name") - escaped_table_name = self._escape_sql_identifier(validated_table_name) - # Check 
if table exists using parameterized query + + # Validate table name using basic validation + validated_table_name = self._validate_identifier(table_name, "table name") + + # Check if table exists using SQLAlchemy parameterized query if raise_if_not_exists: - result = self.connection.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"), {"table_name": validated_table_name}) + result = self.connection.execute( + text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"), + {"table_name": validated_table_name} + ) if result.fetchone() is None: raise ValueError(f"Table '{validated_table_name}' does not exist.") - # Execute DROP statement with escaped identifier - self.connection.execute(text(f"DROP TABLE IF EXISTS {escaped_table_name};")) + + # Execute DROP TABLE query with proper identifier escaping + escaped_table_name = f'"{validated_table_name}"' + query = f"DROP TABLE IF EXISTS {escaped_table_name}" + self.connection.execute(text(query)) # def read(self): # pass diff --git a/stat_log_db/tests/test_sql_injection.py b/stat_log_db/tests/test_sql_injection.py index 447ba96..bf9bd37 100644 --- a/stat_log_db/tests/test_sql_injection.py +++ b/stat_log_db/tests/test_sql_injection.py @@ -6,7 +6,7 @@ import sys from pathlib import Path -from stat_log_db.db import MemDB +from stat_log_db.db import Database # Add the src directory to the path to import the module @@ -17,7 +17,7 @@ @pytest.fixture def mem_db(): """Create a test in-memory database and clean up after tests.""" - sl_db = MemDB({ + sl_db = Database({ "is_mem": True, "fkey_constraint": True }) From 6a2d5575f8d3f99f81b4b1a4101755b371f88b69 Mon Sep 17 00:00:00 2001 From: bjthres1 Date: Tue, 9 Sep 2025 19:24:25 -0400 Subject: [PATCH 3/3] feat(db): enhance identifier and col type validation Added comprehensive validation for SQL identifiers and SQLite column types, ensuring safety against reserved words and format issues. 
def _validate_identifier(self, identifier: str, identifier_type: str = "identifier") -> str:
    """
    Validate a SQL identifier (e.g. a table or column name) and return it unchanged.

    An identifier is accepted only if it is a non-empty string made of ASCII
    letters, digits and underscores, does not start with a digit, and is not
    one of the reserved words listed below (compared case-insensitively).

    Args:
        identifier: The candidate identifier.
        identifier_type: Label used in error messages, such as "table name"
            or "column name".

    Returns:
        The identifier exactly as it was passed in.

    Raises:
        TypeError: If ``identifier`` is not a string.
        ValueError: If ``identifier`` is empty, contains characters outside
            the allowed set, or is a reserved word.
    """
    if not isinstance(identifier, str):
        raise TypeError(f"SQL {identifier_type} must be a string, got {type(identifier).__name__}")
    if len(identifier) == 0:
        raise ValueError(f"SQL {identifier_type} cannot be empty")

    # Use fullmatch rather than match with a '$' anchor: '$' also matches just
    # before a trailing newline, which would let "users\n" pass validation.
    if not re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', identifier):
        raise ValueError(f"Invalid SQL {identifier_type}: '{identifier}'. Must start with letter or underscore and contain only letters, numbers, and underscores.")

    # Reserved words that are rejected as identifiers
    sql_reserved_words = {
        'ABORT', 'ACTION', 'ADD', 'AFTER', 'ALL', 'ALTER', 'ANALYZE', 'AND', 'AS', 'ASC',
        'ATTACH', 'AUTOINCREMENT', 'BEFORE', 'BEGIN', 'BETWEEN', 'BY', 'CASCADE', 'CASE',
        'CAST', 'CHECK', 'COLLATE', 'COLUMN', 'COMMIT', 'CONFLICT', 'CONSTRAINT', 'CREATE',
        'CROSS', 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', 'DATABASE', 'DEFAULT',
        'DEFERRABLE', 'DEFERRED', 'DELETE', 'DESC', 'DETACH', 'DISTINCT', 'DROP', 'EACH',
        'ELSE', 'END', 'ESCAPE', 'EXCEPT', 'EXCLUSIVE', 'EXISTS', 'EXPLAIN', 'FAIL', 'FOR',
        'FOREIGN', 'FROM', 'FULL', 'GLOB', 'GROUP', 'HAVING', 'IF', 'IGNORE', 'IMMEDIATE',
        'IN', 'INDEX', 'INDEXED', 'INITIALLY', 'INNER', 'INSERT', 'INSTEAD', 'INTERSECT',
        'INTO', 'IS', 'ISNULL', 'JOIN', 'KEY', 'LEFT', 'LIKE', 'LIMIT', 'MATCH', 'NATURAL',
        'NO', 'NOT', 'NOTNULL', 'NULL', 'OF', 'OFFSET', 'ON', 'OR', 'ORDER', 'OUTER', 'PLAN',
        'PRAGMA', 'PRIMARY', 'QUERY', 'RAISE', 'RECURSIVE', 'REFERENCES', 'REGEXP', 'REINDEX',
        'RELEASE', 'RENAME', 'REPLACE', 'RESTRICT', 'RIGHT', 'ROLLBACK', 'ROW', 'SAVEPOINT',
        'SELECT', 'SET', 'TABLE', 'TEMP', 'TEMPORARY', 'THEN', 'TO', 'TRANSACTION', 'TRIGGER',
        'UNION', 'UNIQUE', 'UPDATE', 'USING', 'VACUUM', 'VALUES', 'VIEW', 'VIRTUAL', 'WHEN',
        'WHERE', 'WITH', 'WITHOUT'
    }

    if identifier.upper() in sql_reserved_words:
        raise ValueError(f"SQL {identifier_type} '{identifier}' is a reserved word")

    return identifier
def _build_column_definition(self, col_name: str, col_type: str) -> str:
    """
    Render a single column entry for a CREATE TABLE statement.

    The column name is checked with ``_validate_identifier`` and wrapped in
    double quotes; the column type is passed through
    ``_validate_column_type`` and used as returned. Any exception raised by
    either validator propagates to the caller.

    Returns:
        A string of the form ``"<name>" <TYPE>``.
    """
    # Name is validated before the type, so a bad name is reported first
    safe_name = self._validate_identifier(col_name, "column name")
    safe_type = self._validate_column_type(col_type)
    quoted_name = '"' + safe_name + '"'
    return " ".join((quoted_name, safe_type))
+ + Args: + table_name: Name of the table to create + columns: List of (column_name, column_type) tuples + temp_table: Whether to create a temporary table + raise_if_exists: Whether to raise an error if table already exists + """ + # Validate arguments if not isinstance(table_name, str): raise_auto_arg_type_error("table_name") if len(table_name) == 0: - raise ValueError("'table_name' argument of create_table cannot be an empty string!") - # Validate temp_table argument + raise ValueError("'table_name' argument cannot be an empty string") if not isinstance(temp_table, bool): raise_auto_arg_type_error("temp_table") if not isinstance(raise_if_exists, bool): raise_auto_arg_type_error("raise_if_exists") - # Validate columns argument - if (not isinstance(columns, list)) or (not all( + if not isinstance(columns, list) or not all( isinstance(col, tuple) and len(col) == 2 - and isinstance(col[0], str) - and isinstance(col[1], str) - for col in columns)): + and isinstance(col[0], str) and isinstance(col[1], str) + for col in columns + ): raise_auto_arg_type_error("columns") - # Validate table name using basic validation + # Validate and normalize table name validated_table_name = self._validate_identifier(table_name, "table name") + escaped_table_name = f'"{validated_table_name}"' # Check if table already exists using SQLAlchemy parameterized query if raise_if_exists: - result = self.connection.execute( - text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"), - {"table_name": validated_table_name} - ) + check_query = text("SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name") + result = self.connection.execute(check_query, {"table_name": validated_table_name}) if result.fetchone() is not None: - raise ValueError(f"Table '{validated_table_name}' already exists.") + raise ValueError(f"Table '{validated_table_name}' already exists") - # Validate columns and build column definitions + # Build column definitions using helper method 
column_definitions = [] for col_name, col_type in columns: - # Validate column name - validated_col_name = self._validate_identifier(col_name, "column name") - - # Validate column type - allow only safe, known SQLite types - allowed_types = { - 'TEXT', 'INTEGER', 'REAL', 'BLOB', 'NUMERIC', - 'VARCHAR', 'CHAR', 'NVARCHAR', 'NCHAR', - 'CLOB', 'DATE', 'DATETIME', 'TIMESTAMP', - 'BOOLEAN', 'DECIMAL', 'DOUBLE', 'FLOAT', - 'INT', 'BIGINT', 'SMALLINT', 'TINYINT' - } - - # Allow type specifications with length/precision (e.g., VARCHAR(50), DECIMAL(10,2)) - base_type = re.match(r'^([A-Z]+)', col_type.upper()) - if not base_type or base_type.group(1) not in allowed_types: - raise ValueError(f"Unsupported column type: '{col_type}'. Must be one of: {', '.join(sorted(allowed_types))}") - - # Basic validation for type specification format - if not re.match(r'^[A-Z]+(\([0-9,\s]+\))?$', col_type.upper()): - raise ValueError(f"Invalid column type format: '{col_type}'") - - # Use double quotes for identifier escaping (SQLite standard) - escaped_col_name = f'"{validated_col_name}"' - column_definitions.append(f"{escaped_col_name} {col_type.upper()}") + column_def = self._build_column_definition(col_name, col_type) + column_definitions.append(column_def) - columns_qstr = ",\n ".join(column_definitions) + # Format column definitions for query + columns_clause = ",\n ".join(column_definitions) - # Build CREATE TABLE statement with proper identifier escaping + # Build CREATE TABLE statement temp_keyword = " TEMPORARY" if temp_table else "" - escaped_table_name = f'"{validated_table_name}"' - - query = f"""CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} ( + create_query = f"""CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} ( id INTEGER PRIMARY KEY AUTOINCREMENT, - {columns_qstr} + {columns_clause} )""" - # Execute using SQLAlchemy's text() construct - self.connection.execute(text(query)) + # Execute using SQLAlchemy's text() construct for safe execution + 
def create_table_with_sqlalchemy_ddl(self, table_name: str, columns: list[tuple[str, str]], temp_table: bool = True, raise_if_exists: bool = True):
    """
    Create a table from SQLAlchemy ``Column`` objects.

    Every table gets an ``id`` integer primary-key column followed by the
    caller's columns. Column types are validated with
    ``_validate_column_type`` and converted with
    ``_map_sqlite_type_to_sqlalchemy``.

    Args:
        table_name: Name of the table to create.
        columns: List of ``(column_name, column_type)`` tuples.
        temp_table: If True, emit ``CREATE TEMPORARY TABLE IF NOT EXISTS``
            as raw SQL on ``self.connection``. If False, create the table
            through ``MetaData.create_all`` on ``self._db.engine``.
        raise_if_exists: If True, raise when a table with this name is
            already recorded in the schema table(s) consulted below.

    Raises:
        ValueError: If ``table_name`` is not a non-empty string, fails
            identifier validation, or (with ``raise_if_exists``) already
            exists; also for invalid column names or types.
    """
    # Validate arguments
    if not isinstance(table_name, str) or len(table_name) == 0:
        raise ValueError("table_name must be a non-empty string")
    if not isinstance(temp_table, bool):
        raise_auto_arg_type_error("temp_table")
    if not isinstance(raise_if_exists, bool):
        raise_auto_arg_type_error("raise_if_exists")
    if not isinstance(columns, list) or not all(
        isinstance(col, tuple) and len(col) == 2
        and isinstance(col[0], str) and isinstance(col[1], str)
        for col in columns
    ):
        raise_auto_arg_type_error("columns")

    validated_table_name = self._validate_identifier(table_name, "table name")

    if raise_if_exists:
        # SQLite lists temporary tables in sqlite_temp_master, not in
        # sqlite_master. When creating a temporary table both must be
        # consulted, otherwise an existing temp table is never detected and
        # CREATE ... IF NOT EXISTS silently does nothing.
        exists_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=:table_name"
        if temp_table:
            exists_sql += (
                " UNION ALL"
                " SELECT name FROM sqlite_temp_master WHERE type='table' AND name=:table_name"
            )
        result = self.connection.execute(text(exists_sql), {"table_name": validated_table_name})
        if result.fetchone() is not None:
            raise ValueError(f"Table '{validated_table_name}' already exists")

    # Build SQLAlchemy Column objects: implicit primary key first, then the
    # caller's columns in the order given.
    sqlalchemy_columns = [
        Column('id', Integer, primary_key=True, autoincrement=True)
    ]
    for col_name, col_type in columns:
        validated_col_name = self._validate_identifier(col_name, "column name")
        validated_col_type = self._validate_column_type(col_type)
        sqlalchemy_type = self._map_sqlite_type_to_sqlalchemy(validated_col_type)
        sqlalchemy_columns.append(Column(validated_col_name, sqlalchemy_type))

    if temp_table:
        # Temporary tables are emitted as raw SQL so the statement carries
        # TEMPORARY, IF NOT EXISTS and AUTOINCREMENT explicitly. The type
        # text is whatever the mapped SQLAlchemy type compiles to for the
        # engine's dialect.
        dialect = self._db.engine.dialect
        column_defs = []
        for col in sqlalchemy_columns:
            column_def = f'"{col.name}" {col.type.compile(dialect)}'
            if col.name == 'id':
                column_def += ' PRIMARY KEY AUTOINCREMENT'
            column_defs.append(column_def)

        columns_clause = ",\n    ".join(column_defs)
        escaped_table_name = f'"{validated_table_name}"'
        create_query = f"""CREATE TEMPORARY TABLE IF NOT EXISTS {escaped_table_name} (
    {columns_clause}
)"""
        self.connection.execute(text(create_query))
    else:
        # Regular tables go through SQLAlchemy's DDL machinery.
        metadata = MetaData()
        table = Table(
            validated_table_name,
            metadata,
            *sqlalchemy_columns
        )
        # When raise_if_exists is True absence was already confirmed above,
        # so the extra existence probe is skipped; otherwise checkfirst makes
        # creation a no-op for an existing table.
        # NOTE(review): this binds to the engine while the existence check
        # above uses self.connection - confirm that is intended.
        metadata.create_all(self._db.engine, tables=[table], checkfirst=not raise_if_exists)