11# from abc import ABC, abstractmethod
2+ import re
23import sqlite3
34import uuid
45from typing import Any
@@ -317,42 +318,132 @@ def fetchone(self):
def fetchall(self):
    """Return all remaining rows of the cursor's current result set.

    This is a direct delegate to ``self.cursor.fetchall()``; the value is
    passed through unchanged.
    """
    remaining_rows = self.cursor.fetchall()
    return remaining_rows
319320
321+ def _validate_sql_identifier (self , identifier : str , identifier_type : str = "identifier" ) -> str :
322+ """
323+ Validate and sanitize SQL identifiers (table names, column names) to prevent SQL injection.
324+
325+ Args:
326+ identifier: The identifier to validate
327+ identifier_type: Type of identifier for error messages (e.g., "table name", "column name")
328+
329+ Returns:
330+ The validated identifier
331+
332+ Raises:
333+ ValueError: If the identifier is invalid or potentially dangerous
334+ """
335+ if not isinstance (identifier , str ):
336+ raise TypeError (f"SQL { identifier_type } must be a string, got { type (identifier ).__name__ } " )
337+
338+ if len (identifier ) == 0 :
339+ raise ValueError (f"SQL { identifier_type } cannot be empty" )
340+
341+ # Check for valid identifier pattern: starts with letter/underscore, contains only alphanumeric/underscore
342+ if not re .match (r'^[a-zA-Z_][a-zA-Z0-9_]*$' , identifier ):
343+ raise ValueError (f"Invalid SQL { identifier_type } : '{ identifier } '. Must start with letter or underscore and contain only letters, numbers, and underscores." )
344+
345+ # Check against SQLite reserved words (common ones that could cause issues)
346+ reserved_words = {
347+ 'abort' , 'action' , 'add' , 'after' , 'all' , 'alter' , 'analyze' , 'and' , 'as' , 'asc' ,
348+ 'attach' , 'autoincrement' , 'before' , 'begin' , 'between' , 'by' , 'cascade' , 'case' ,
349+ 'cast' , 'check' , 'collate' , 'column' , 'commit' , 'conflict' , 'constraint' , 'create' ,
350+ 'cross' , 'current' , 'current_date' , 'current_time' , 'current_timestamp' , 'database' ,
351+ 'default' , 'deferrable' , 'deferred' , 'delete' , 'desc' , 'detach' , 'distinct' , 'do' ,
352+ 'drop' , 'each' , 'else' , 'end' , 'escape' , 'except' , 'exclusive' , 'exists' , 'explain' ,
353+ 'fail' , 'filter' , 'following' , 'for' , 'foreign' , 'from' , 'full' , 'glob' , 'group' ,
354+ 'having' , 'if' , 'ignore' , 'immediate' , 'in' , 'index' , 'indexed' , 'initially' , 'inner' ,
355+ 'insert' , 'instead' , 'intersect' , 'into' , 'is' , 'isnull' , 'join' , 'key' , 'left' ,
356+ 'like' , 'limit' , 'match' , 'natural' , 'no' , 'not' , 'notnull' , 'null' , 'of' , 'offset' ,
357+ 'on' , 'or' , 'order' , 'outer' , 'over' , 'partition' , 'plan' , 'pragma' , 'preceding' ,
358+ 'primary' , 'query' , 'raise' , 'range' , 'recursive' , 'references' , 'regexp' , 'reindex' ,
359+ 'release' , 'rename' , 'replace' , 'restrict' , 'right' , 'rollback' , 'row' , 'rows' ,
360+ 'savepoint' , 'select' , 'set' , 'table' , 'temp' , 'temporary' , 'then' , 'to' , 'transaction' ,
361+ 'trigger' , 'unbounded' , 'union' , 'unique' , 'update' , 'using' , 'vacuum' , 'values' ,
362+ 'view' , 'virtual' , 'when' , 'where' , 'window' , 'with' , 'without'
363+ }
364+
365+ if identifier .lower () in reserved_words :
366+ raise ValueError (f"SQL { identifier_type } '{ identifier } ' is a reserved word and cannot be used" )
367+
368+ return identifier
369+
370+ def _escape_sql_identifier (self , identifier : str ) -> str :
371+ """
372+ Escape SQL identifier by wrapping in double quotes and escaping any internal quotes.
373+ This should only be used after validation.
374+ """
375+ # Escape any double quotes in the identifier by doubling them
376+ escaped = identifier .replace ('"' , '""' )
377+ return f'"{ escaped } "'
378+
def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_table: bool = True, raise_if_exists: bool = True):
    """
    Create a table with an auto-incrementing ``id`` primary key plus the given columns.

    Args:
        table_name: Name of the table; validated with ``_validate_sql_identifier``
            and quoted with ``_escape_sql_identifier`` before use.
        columns: List of ``(column_name, column_type)`` string pairs. May be
            empty, in which case the table has only the ``id`` column. Column
            types must be one of the allowed base types below, optionally
            followed by a numeric length/precision such as ``VARCHAR(50)`` or
            ``DECIMAL(10,2)``; they are upper-cased in the generated statement.
        temp_table: When True, a TEMPORARY table is created.
        raise_if_exists: When True, raise if ``sqlite_master`` already lists a
            table with this name.

    Raises:
        ValueError: If the table name or a column name is invalid, a column type
            is unsupported or malformed, or the table already exists and
            ``raise_if_exists`` is True.

    Wrong argument types are reported through ``raise_auto_arg_type_error``.
    """
    # Validate table_name argument
    if not isinstance(table_name, str):
        raise_auto_arg_type_error("table_name")
    if len(table_name) == 0:
        raise ValueError("'table_name' argument of create_table cannot be an empty string!")

    # Validate and quote the table name; identifiers cannot be bound as parameters
    validated_table_name = self._validate_sql_identifier(table_name, "table name")
    escaped_table_name = self._escape_sql_identifier(validated_table_name)

    if not isinstance(raise_if_exists, bool):
        raise_auto_arg_type_error("raise_if_exists")

    # Check if table already exists using a parameterized query.
    # NOTE(review): SQLite records temporary tables in sqlite_temp_master rather
    # than sqlite_master, so this check presumably cannot see an existing
    # TEMPORARY table - confirm whether that case should raise as well.
    if raise_if_exists:
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,))
        if self.cursor.fetchone() is not None:
            raise ValueError(f"Table '{validated_table_name}' already exists.")

    # Validate temp_table argument
    if not isinstance(temp_table, bool):
        raise_auto_arg_type_error("temp_table")

    # Validate columns argument
    if (not isinstance(columns, list)) or (not all(
            isinstance(col, tuple) and len(col) == 2
            and isinstance(col[0], str)
            and isinstance(col[1], str)
            for col in columns)):
        raise_auto_arg_type_error("columns")

    # Column types are interpolated into the statement, so only known type names are accepted
    allowed_types = frozenset({
        'TEXT', 'INTEGER', 'REAL', 'BLOB', 'NUMERIC',
        'VARCHAR', 'CHAR', 'NVARCHAR', 'NCHAR',
        'CLOB', 'DATE', 'DATETIME', 'TIMESTAMP',
        'BOOLEAN', 'DECIMAL', 'DOUBLE', 'FLOAT',
        'INT', 'BIGINT', 'SMALLINT', 'TINYINT',
    })

    # Start from the implicit id column so that an empty `columns` list still
    # yields a valid statement (no dangling comma before the closing paren).
    column_defs = ["id INTEGER PRIMARY KEY AUTOINCREMENT"]
    for col_name, col_type in columns:
        validated_col_name = self._validate_sql_identifier(col_name, "column name")
        escaped_col_name = self._escape_sql_identifier(validated_col_name)

        normalized_type = col_type.upper()

        # The leading run of letters is the base type name (e.g. VARCHAR in VARCHAR(50))
        base_type = re.match(r'[A-Z]+', normalized_type)
        if base_type is None or base_type.group(0) not in allowed_types:
            raise ValueError(f"Unsupported column type: '{col_type}'. Must be one of: {', '.join(sorted(allowed_types))}")

        # fullmatch instead of '^...$': '$' would also accept a trailing newline
        if re.fullmatch(r'[A-Z]+(\([0-9,\s]+\))?', normalized_type) is None:
            raise ValueError(f"Invalid column type format: '{col_type}'")

        column_defs.append(f"{escaped_col_name} {normalized_type}")

    # Assemble full query with escaped identifiers
    temp_keyword = " TEMPORARY" if temp_table else ""
    columns_qstr = ",\n    ".join(column_defs)
    query = f"CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} (\n    {columns_qstr}\n);"

    self.execute(query)
357448
358449 def drop_table (self , table_name : str , raise_if_not_exists : bool = False ):
@@ -361,13 +452,22 @@ def drop_table(self, table_name: str, raise_if_not_exists: bool = False):
361452 raise_auto_arg_type_error ("table_name" )
362453 if len (table_name ) == 0 :
363454 raise ValueError (f"'table_name' argument of drop_table cannot be an empty string!" )
455+
456+ # Validate and sanitize table name
457+ validated_table_name = self ._validate_sql_identifier (table_name , "table name" )
458+ escaped_table_name = self ._escape_sql_identifier (validated_table_name )
459+
364460 if not isinstance (raise_if_not_exists , bool ):
365461 raise_auto_arg_type_error ("raise_if_not_exists" )
462+
463+ # Check if table exists using parameterized query
366464 if raise_if_not_exists :
367- self .cursor .execute (f "SELECT name FROM sqlite_master WHERE type='table' AND name=' { table_name } ';" )
465+ self .cursor .execute ("SELECT name FROM sqlite_master WHERE type='table' AND name=?;" , ( validated_table_name ,) )
368466 if self .cursor .fetchone () is None :
369- raise ValueError (f"Table '{ table_name } ' does not exist." )
370- self .cursor .execute (f"DROP TABLE IF EXISTS '{ table_name } ';" )
467+ raise ValueError (f"Table '{ validated_table_name } ' does not exist." )
468+
469+ # Execute DROP statement with escaped identifier
470+ self .cursor .execute (f"DROP TABLE IF EXISTS { escaped_table_name } ;" )
371471
372472 # def read(self):
373473
0 commit comments