Skip to content

Commit 4304de5

Browse files
committed
Implement basic sanitization, injection tests, improve git-bash lookup on windows
1 parent 0aacc4b commit 4304de5

File tree

3 files changed

+332
-18
lines changed

3 files changed

+332
-18
lines changed

stat_log_db/src/stat_log_db/db.py

Lines changed: 117 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# from abc import ABC, abstractmethod
2+
import re
23
import sqlite3
34
import uuid
45
from typing import Any
@@ -317,42 +318,132 @@ def fetchone(self):
317318
def fetchall(self):
318319
return self.cursor.fetchall()
319320

321+
def _validate_sql_identifier(self, identifier: str, identifier_type: str = "identifier") -> str:
    """
    Validate a SQL identifier (table name, column name) before it is placed in a query.

    The identifier is checked, never rewritten: on success the exact input
    string is returned.

    Args:
        identifier: The identifier to validate
        identifier_type: Type of identifier for error messages (e.g., "table name", "column name")

    Returns:
        The validated identifier, unchanged

    Raises:
        TypeError: If the identifier is not a string
        ValueError: If the identifier is empty, contains characters other than
            ASCII letters, digits and underscores, starts with a digit, or is
            one of the listed SQLite reserved words (compared case-insensitively)
    """
    if not isinstance(identifier, str):
        raise TypeError(f"SQL {identifier_type} must be a string, got {type(identifier).__name__}")

    if len(identifier) == 0:
        raise ValueError(f"SQL {identifier_type} cannot be empty")

    # Use fullmatch rather than match with a '$' anchor: '$' also matches just
    # before a trailing newline, which would let "users\n" through.
    if not re.fullmatch(r'[a-zA-Z_][a-zA-Z0-9_]*', identifier):
        raise ValueError(f"Invalid SQL {identifier_type}: '{identifier}'. Must start with letter or underscore and contain only letters, numbers, and underscores.")

    # Check against SQLite reserved words (common ones that could cause issues)
    reserved_words = {
        'abort', 'action', 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc',
        'attach', 'autoincrement', 'before', 'begin', 'between', 'by', 'cascade', 'case',
        'cast', 'check', 'collate', 'column', 'commit', 'conflict', 'constraint', 'create',
        'cross', 'current', 'current_date', 'current_time', 'current_timestamp', 'database',
        'default', 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', 'do',
        'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', 'exists', 'explain',
        'fail', 'filter', 'following', 'for', 'foreign', 'from', 'full', 'glob', 'group',
        'having', 'if', 'ignore', 'immediate', 'in', 'index', 'indexed', 'initially', 'inner',
        'insert', 'instead', 'intersect', 'into', 'is', 'isnull', 'join', 'key', 'left',
        'like', 'limit', 'match', 'natural', 'no', 'not', 'notnull', 'null', 'of', 'offset',
        'on', 'or', 'order', 'outer', 'over', 'partition', 'plan', 'pragma', 'preceding',
        'primary', 'query', 'raise', 'range', 'recursive', 'references', 'regexp', 'reindex',
        'release', 'rename', 'replace', 'restrict', 'right', 'rollback', 'row', 'rows',
        'savepoint', 'select', 'set', 'table', 'temp', 'temporary', 'then', 'to', 'transaction',
        'trigger', 'unbounded', 'union', 'unique', 'update', 'using', 'vacuum', 'values',
        'view', 'virtual', 'when', 'where', 'window', 'with', 'without'
    }

    if identifier.lower() in reserved_words:
        raise ValueError(f"SQL {identifier_type} '{identifier}' is a reserved word and cannot be used")

    return identifier
369+
370+
def _escape_sql_identifier(self, identifier: str) -> str:
    """
    Quote a SQL identifier so it can be embedded directly in a statement.

    Every double quote inside the identifier is doubled, then the whole
    string is wrapped in double quotes. Intended to be called only on
    identifiers that have already been validated.
    """
    doubled_quotes = identifier.replace('"', '""')
    return '"' + doubled_quotes + '"'
378+
320379
def create_table(self, table_name: str, columns: list[tuple[str, str]], temp_table: bool = True, raise_if_exists: bool = True):
    """
    Create a table with an auto-incrementing ``id`` primary key plus the given columns.

    Table and column names are validated with ``_validate_sql_identifier`` and
    quoted with ``_escape_sql_identifier``; column types are checked against a
    fixed allow-list, so no caller-supplied text reaches the statement unchecked.

    Args:
        table_name: Name of the table to create.
        columns: List of ``(column_name, column_type)`` string pairs. May be
            empty, in which case the table only has the ``id`` column.
        temp_table: When True the table is created with ``TEMPORARY``.
        raise_if_exists: When True, raise if ``sqlite_master`` already lists a
            table with this name. When False an existing table is left as is
            (the statement uses ``IF NOT EXISTS``).

    Raises:
        ValueError: If the table name is empty or invalid, the table already
            exists (with ``raise_if_exists``), a column name is invalid, or a
            column type is unsupported or malformed.

    Arguments of the wrong type are reported through
    ``raise_auto_arg_type_error``, which is defined elsewhere in this module.
    """
    # Validate table_name argument
    if not isinstance(table_name, str):
        raise_auto_arg_type_error("table_name")
    if len(table_name) == 0:
        raise ValueError("'table_name' argument of create_table cannot be an empty string!")

    # Validate and sanitize table name
    validated_table_name = self._validate_sql_identifier(table_name, "table name")
    escaped_table_name = self._escape_sql_identifier(validated_table_name)

    if not isinstance(raise_if_exists, bool):
        raise_auto_arg_type_error("raise_if_exists")

    # Check if table already exists using parameterized query
    # NOTE(review): sqlite_master only lists objects in the main schema;
    # TEMPORARY tables are recorded in sqlite_temp_master, so this check does
    # not see an existing temp table. Confirm whether that is intended.
    if raise_if_exists:
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,))
        if self.cursor.fetchone() is not None:
            raise ValueError(f"Table '{validated_table_name}' already exists.")

    # Validate temp_table argument
    if not isinstance(temp_table, bool):
        raise_auto_arg_type_error("temp_table")

    # Validate columns argument
    if (not isinstance(columns, list)) or (not all(
            isinstance(col, tuple) and len(col) == 2
            and isinstance(col[0], str)
            and isinstance(col[1], str)
            for col in columns)):
        raise_auto_arg_type_error("columns")

    # Allow only safe, known SQLite types (built once, not per column)
    allowed_types = {
        'TEXT', 'INTEGER', 'REAL', 'BLOB', 'NUMERIC',
        'VARCHAR', 'CHAR', 'NVARCHAR', 'NCHAR',
        'CLOB', 'DATE', 'DATETIME', 'TIMESTAMP',
        'BOOLEAN', 'DECIMAL', 'DOUBLE', 'FLOAT',
        'INT', 'BIGINT', 'SMALLINT', 'TINYINT'
    }

    # The id column is always first; user columns are appended after validation.
    # Joining one list avoids the dangling comma an empty `columns` used to leave.
    column_defs = ["id INTEGER PRIMARY KEY AUTOINCREMENT"]
    for col_name, col_type in columns:
        # Validate column name
        validated_col_name = self._validate_sql_identifier(col_name, "column name")
        escaped_col_name = self._escape_sql_identifier(validated_col_name)

        normalized_type = col_type.upper()

        # Allow type specifications with length/precision (e.g., VARCHAR(50), DECIMAL(10,2))
        base_type = re.match(r'^([A-Z]+)', normalized_type)
        if not base_type or base_type.group(1) not in allowed_types:
            raise ValueError(f"Unsupported column type: '{col_type}'. Must be one of: {', '.join(sorted(allowed_types))}")

        # Basic validation for type specification format. fullmatch is used
        # because a '$' anchor would also accept a trailing newline.
        if not re.fullmatch(r'[A-Z]+(\([0-9,\s]+\))?', normalized_type):
            raise ValueError(f"Invalid column type format: '{col_type}'")

        column_defs.append(f"{escaped_col_name} {normalized_type}")

    columns_qstr = ",\n    ".join(column_defs)

    # Assemble full query with escaped identifiers
    temp_keyword = " TEMPORARY" if temp_table else ""
    query = (
        f"CREATE{temp_keyword} TABLE IF NOT EXISTS {escaped_table_name} (\n"
        f"    {columns_qstr}\n"
        f");"
    )

    self.execute(query)
357448

358449
def drop_table(self, table_name: str, raise_if_not_exists: bool = False):
@@ -361,13 +452,22 @@ def drop_table(self, table_name: str, raise_if_not_exists: bool = False):
361452
raise_auto_arg_type_error("table_name")
362453
if len(table_name) == 0:
363454
raise ValueError(f"'table_name' argument of drop_table cannot be an empty string!")
455+
456+
# Validate and sanitize table name
457+
validated_table_name = self._validate_sql_identifier(table_name, "table name")
458+
escaped_table_name = self._escape_sql_identifier(validated_table_name)
459+
364460
if not isinstance(raise_if_not_exists, bool):
365461
raise_auto_arg_type_error("raise_if_not_exists")
462+
463+
# Check if table exists using parameterized query
366464
if raise_if_not_exists:
367-
self.cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}';")
465+
self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?;", (validated_table_name,))
368466
if self.cursor.fetchone() is None:
369-
raise ValueError(f"Table '{table_name}' does not exist.")
370-
self.cursor.execute(f"DROP TABLE IF EXISTS '{table_name}';")
467+
raise ValueError(f"Table '{validated_table_name}' does not exist.")
468+
469+
# Execute DROP statement with escaped identifier
470+
self.cursor.execute(f"DROP TABLE IF EXISTS {escaped_table_name};")
371471

372472
# def read(self):
373473

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
"""
2+
Test SQL injection protection in create_table and drop_table methods.
3+
"""
4+
5+
import pytest
6+
import sys
7+
from pathlib import Path
8+
9+
# Add the src directory to the path to import the module
10+
ROOT = Path(__file__).resolve().parent.parent
11+
sys.path.insert(0, str(ROOT / "stat_log_db" / "src"))
12+
13+
from stat_log_db.db import MemDB
14+
15+
16+
@pytest.fixture
def mem_db():
    """Provide the object returned by MemDB.init_db for a fresh in-memory database.

    The database is closed once the test using the fixture has finished.
    """
    settings = {"is_mem": True, "fkey_constraint": True}
    database = MemDB(settings)
    connection = database.init_db(True)
    yield connection
    # Teardown: runs after the test body completes
    database.close_db()
27+
28+
29+
class TestSQLInjectionProtection:
    """Checks that create_table and drop_table reject unsafe names and column types."""

    def test_malicious_table_name_create(self, mem_db):
        """An injection payload used as a table name raises ValueError."""
        payload = "test'; DROP TABLE users; --"
        with pytest.raises(ValueError, match="Invalid SQL table name"):
            mem_db.create_table(payload, [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

    def test_reserved_word_table_name(self, mem_db):
        """A reserved word is not accepted as a table name."""
        with pytest.raises(ValueError, match="is a reserved word"):
            mem_db.create_table("select", [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

    def test_invalid_characters_table_name(self, mem_db):
        """A hyphenated table name is rejected as invalid."""
        with pytest.raises(ValueError, match="Invalid SQL table name"):
            mem_db.create_table("test-table", [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

    def test_malicious_column_name(self, mem_db):
        """An injection payload used as a column name raises ValueError."""
        payload_columns = [('notes\'; DROP TABLE users; --', 'TEXT')]
        with pytest.raises(ValueError, match="Invalid SQL column name"):
            mem_db.create_table("test_table", payload_columns, temp_table=False, raise_if_exists=True)

    def test_invalid_column_type(self, mem_db):
        """An unknown column type carrying an injection payload is rejected."""
        payload_columns = [('notes', 'MALICIOUS_TYPE; DROP TABLE users; --')]
        with pytest.raises(ValueError, match="Unsupported column type"):
            mem_db.create_table("test_table", payload_columns, temp_table=False, raise_if_exists=True)

    def test_valid_table_creation(self, mem_db):
        """A well-formed table can be created, written to and read back."""
        # Creation itself must not raise
        table_columns = [('notes', 'TEXT'), ('count', 'INTEGER')]
        mem_db.create_table("test_table", table_columns, temp_table=False, raise_if_exists=True)

        # Insert one row to prove the table exists
        mem_db.execute("INSERT INTO test_table (notes, count) VALUES (?, ?);", ("test note", 42))
        mem_db.commit()

        # Read the row back
        mem_db.execute("SELECT * FROM test_table;")
        rows = mem_db.fetchall()
        assert len(rows) == 1
        first_row = rows[0]
        assert first_row[1] == "test note"  # index 0 holds the auto-increment id
        assert first_row[2] == 42

    def test_malicious_drop_table_name(self, mem_db):
        """An injection payload passed to drop_table raises ValueError."""
        # A real table exists before the malicious drop is attempted
        mem_db.create_table("test_table", [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

        payload = "test_table'; DROP TABLE sqlite_master; --"
        with pytest.raises(ValueError, match="Invalid SQL table name"):
            mem_db.drop_table(payload, raise_if_not_exists=False)

    def test_valid_drop_table(self, mem_db):
        """Dropping an existing table removes it from sqlite_master."""
        lookup_sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?;"

        mem_db.create_table("test_table", [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

        # Present before the drop
        mem_db.execute(lookup_sql, ("test_table",))
        assert mem_db.fetchone() is not None

        mem_db.drop_table("test_table", raise_if_not_exists=False)

        # Absent after the drop
        mem_db.execute(lookup_sql, ("test_table",))
        assert mem_db.fetchone() is None

    def test_empty_table_name_create(self, mem_db):
        """create_table refuses an empty table name."""
        with pytest.raises(ValueError, match="cannot be an empty string"):
            mem_db.create_table("", [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

    def test_empty_table_name_drop(self, mem_db):
        """drop_table refuses an empty table name."""
        with pytest.raises(ValueError, match="cannot be an empty string"):
            mem_db.drop_table("", raise_if_not_exists=False)

    def test_empty_column_name(self, mem_db):
        """create_table refuses an empty column name."""
        with pytest.raises(ValueError, match="cannot be empty"):
            mem_db.create_table("test_table", [('', 'TEXT')], temp_table=False, raise_if_exists=True)

    def test_column_name_with_numbers(self, mem_db):
        """Digits are accepted inside column names."""
        numbered_columns = [('column1', 'TEXT'), ('column_2', 'INTEGER')]
        mem_db.create_table("test_table", numbered_columns, temp_table=False, raise_if_exists=True)

    def test_table_name_with_underscore(self, mem_db):
        """A table name may begin with an underscore."""
        mem_db.create_table("_test_table", [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

    def test_valid_column_types(self, mem_db):
        """A table using several supported column types is created without error."""
        typed_columns = [
            ('text_col', 'TEXT'),
            ('int_col', 'INTEGER'),
            ('real_col', 'REAL'),
            ('blob_col', 'BLOB'),
            ('numeric_col', 'NUMERIC'),
            ('varchar_col', 'VARCHAR(255)'),
            ('decimal_col', 'DECIMAL(10,2)'),
        ]

        mem_db.create_table("type_test_table", typed_columns, temp_table=False, raise_if_exists=True)

    def test_case_insensitive_reserved_words(self, mem_db):
        """Reserved words are rejected whatever their letter case."""
        with pytest.raises(ValueError, match="is a reserved word"):
            mem_db.create_table("SELECT", [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

        with pytest.raises(ValueError, match="is a reserved word"):
            mem_db.create_table("Select", [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

    def test_raise_if_exists_functionality(self, mem_db):
        """raise_if_exists controls whether re-creating a table is an error."""
        mem_db.create_table("test_table", [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

        # Second creation with raise_if_exists=True must fail
        with pytest.raises(ValueError, match="already exists"):
            mem_db.create_table("test_table", [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

        # Second creation with raise_if_exists=False must pass
        mem_db.create_table("test_table", [('notes', 'TEXT')], temp_table=False, raise_if_exists=False)

    def test_raise_if_not_exists_functionality(self, mem_db):
        """raise_if_not_exists controls whether dropping a missing table is an error."""
        # Missing table with raise_if_not_exists=True must fail
        with pytest.raises(ValueError, match="does not exist"):
            mem_db.drop_table("nonexistent_table", raise_if_not_exists=True)

        # Missing table with raise_if_not_exists=False must pass
        mem_db.drop_table("nonexistent_table", raise_if_not_exists=False)

    def test_special_characters_rejection(self, mem_db):
        """Table names containing punctuation or whitespace are all rejected."""
        rejected_names = [
            "table;name",
            "table'name",
            'table"name',
            "table name",   # space
            "table-name",   # hyphen
            "table.name",   # dot
            "table(name)",  # parentheses
            "table[name]",  # brackets
            "table{name}",  # braces
            "table@name",   # at symbol
            "table#name",   # hash
            "table$name",   # dollar (should be rejected by our implementation)
        ]

        for candidate in rejected_names:
            with pytest.raises(ValueError, match="Invalid SQL"):
                mem_db.create_table(candidate, [('notes', 'TEXT')], temp_table=False, raise_if_exists=True)

0 commit comments

Comments
 (0)