diff --git a/Makefile b/Makefile index ce57ffdc..d5723162 100644 --- a/Makefile +++ b/Makefile @@ -129,15 +129,15 @@ release: ## Bump version and create re clean: ## Cleanup temporary build artifacts @echo "${INFO} Cleaning working directory... 🧹" @rm -rf .pytest_cache .ruff_cache .hypothesis build/ -rf dist/ .eggs/ .coverage coverage.xml coverage.json htmlcov/ .pytest_cache tests/.pytest_cache tests/**/.pytest_cache .mypy_cache .unasyncd_cache/ .auto_pytabs_cache >/dev/null 2>&1 - @find . -name '*.egg-info' -exec rm -rf {} + >/dev/null 2>&1 - @find . -type f -name '*.egg' -exec rm -f {} + >/dev/null 2>&1 - @find . -name '*.pyc' -exec rm -f {} + >/dev/null 2>&1 - @find . -name '*.pyo' -exec rm -f {} + >/dev/null 2>&1 - @find . -name '*~' -exec rm -f {} + >/dev/null 2>&1 - @find . -name '__pycache__' -exec rm -rf {} + >/dev/null 2>&1 - @find . -name '.ipynb_checkpoints' -exec rm -rf {} + >/dev/null 2>&1 - @find . -name '*.so' -exec rm -f {} + >/dev/null 2>&1 - @find . -name '*.c' -exec rm -f {} + >/dev/null 2>&1 + @find . \( -path ./.venv -o -path ./.git \) -prune -o -name '*.egg-info' -exec rm -rf {} + >/dev/null 2>&1 + @find . \( -path ./.venv -o -path ./.git \) -prune -o -type f -name '*.egg' -exec rm -f {} + >/dev/null 2>&1 + @find . \( -path ./.venv -o -path ./.git \) -prune -o -name '*.pyc' -exec rm -f {} + >/dev/null 2>&1 + @find . \( -path ./.venv -o -path ./.git \) -prune -o -name '*.pyo' -exec rm -f {} + >/dev/null 2>&1 + @find . \( -path ./.venv -o -path ./.git \) -prune -o -name '*~' -exec rm -f {} + >/dev/null 2>&1 + @find . \( -path ./.venv -o -path ./.git \) -prune -o -type d -name '__pycache__' -exec rm -rf {} + >/dev/null 2>&1 + @find . \( -path ./.venv -o -path ./.git \) -prune -o -name '.ipynb_checkpoints' -exec rm -rf {} + >/dev/null 2>&1 + @find . \( -path ./.venv -o -path ./.git \) -prune -o -name '*.so' -exec rm -f {} + >/dev/null 2>&1 + @find . \( -path ./.venv -o -path ./.git \) -prune -o -name '*.c' -exec rm -f {} + >/dev/null 2>&1 @echo "${OK} Working directory cleaned" $(MAKE) docs-clean diff --git a/pyproject.toml b/pyproject.toml index b99a8f60..7af7d1d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,7 +200,7 @@ opt_level = "3" # Maximum optimization (0-3) allow_dirty = true commit = false commit_args = "--no-verify" -current_version = "0.15.0" +current_version = "0.26.0" ignore_missing_files = false ignore_missing_version = false message = "chore(release): bump to v{new_version}" diff --git a/sqlspec/adapters/adbc/config.py b/sqlspec/adapters/adbc/config.py index ba4aeef4..61a10ed4 100644 --- a/sqlspec/adapters/adbc/config.py +++ b/sqlspec/adapters/adbc/config.py @@ -79,6 +79,7 @@ def __init__( statement_config: StatementConfig | None = None, driver_features: dict[str, Any] | None = None, bind_key: str | None = None, + extension_config: "dict[str, dict[str, Any]] | None" = None, ) -> None: """Initialize configuration. 
@@ -88,6 +89,7 @@ def __init__(
            statement_config: Default SQL statement configuration
            driver_features: Driver feature configuration
            bind_key: Optional unique identifier for this configuration
+            extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
        """
        if connection_config is None:
            connection_config = {}
@@ -108,6 +110,7 @@ def __init__(
            statement_config=statement_config,
            driver_features=driver_features or {},
            bind_key=bind_key,
+            extension_config=extension_config,
        )

    def _resolve_driver_name(self) -> str:
diff --git a/sqlspec/adapters/adbc/litestar/__init__.py b/sqlspec/adapters/adbc/litestar/__init__.py
new file mode 100644
index 00000000..25bc6130
--- /dev/null
+++ b/sqlspec/adapters/adbc/litestar/__init__.py
@@ -0,0 +1,5 @@
+"""Litestar integration for ADBC adapter."""
+
+from sqlspec.adapters.adbc.litestar.store import ADBCStore
+
+__all__ = ("ADBCStore",)
diff --git a/sqlspec/adapters/adbc/litestar/store.py b/sqlspec/adapters/adbc/litestar/store.py
new file mode 100644
index 00000000..9ec426d6
--- /dev/null
+++ b/sqlspec/adapters/adbc/litestar/store.py
@@ -0,0 +1,506 @@
+"""ADBC session store for Litestar integration with multi-dialect support.
+
+ADBC (Arrow Database Connectivity) supports multiple database backends including
+PostgreSQL, SQLite, DuckDB, BigQuery, MySQL, and Snowflake. This store automatically
+detects the dialect and adapts SQL syntax accordingly.
+
+Supports:
+- PostgreSQL: BYTEA data type, TIMESTAMPTZ, $1 parameters, ON CONFLICT
+- SQLite: BLOB data type, DATETIME, ? parameters, INSERT OR REPLACE
+- DuckDB: BLOB data type, TIMESTAMP, ? parameters, ON CONFLICT
+- MySQL/MariaDB: BLOB data type, DATETIME, %s parameters, ON DUPLICATE KEY UPDATE
+- BigQuery: BYTES data type, TIMESTAMP, @param parameters, MERGE
+- Snowflake: BINARY data type, TIMESTAMP WITH TIME ZONE, ? parameters, MERGE
+"""
+
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING
+
+from sqlspec.extensions.litestar.store import BaseSQLSpecStore
+from sqlspec.utils.logging import get_logger
+from sqlspec.utils.sync_tools import async_
+
+if TYPE_CHECKING:
+    from sqlspec.adapters.adbc.config import AdbcConfig
+
+logger = get_logger("adapters.adbc.litestar.store")
+
+__all__ = ("ADBCStore",)
+
+
+class ADBCStore(BaseSQLSpecStore["AdbcConfig"]):
+    """ADBC session store using synchronous ADBC driver.
+
+    Implements server-side session storage for Litestar using ADBC
+    (Arrow Database Connectivity) via the synchronous driver. Uses
+    Litestar's sync_to_thread utility to provide an async interface
+    compatible with the Store protocol.
+
+    ADBC supports multiple database backends (PostgreSQL, SQLite, DuckDB, etc.).
+    The SQL schema is optimized for PostgreSQL by default, but can work with
+    other backends that support TIMESTAMPTZ and BYTEA equivalents.
+
+    Provides efficient session management with:
+    - Sync operations wrapped for async compatibility
+    - INSERT ON CONFLICT (UPSERT) for PostgreSQL
+    - Automatic expiration handling with TIMESTAMPTZ
+    - Efficient cleanup of expired sessions
+
+    Args:
+        config: AdbcConfig instance.
+        table_name: Name of the session table. Defaults to "litestar_session".
+
+    Example:
+        from sqlspec.adapters.adbc import AdbcConfig
+        from sqlspec.adapters.adbc.litestar.store import ADBCStore
+
+        config = AdbcConfig(
+            connection_config={
+                "uri": "postgresql://user:pass@localhost/db"
+            }
+        )
+        store = ADBCStore(config)
+        await store.create_table()
+    """
+
+    __slots__ = ("_dialect",)
+
+    def __init__(self, config: "AdbcConfig", table_name: str = "litestar_session") -> None:
+        """Initialize ADBC session store.
+
+        Args:
+            config: AdbcConfig instance.
+            table_name: Name of the session table.
+        """
+        super().__init__(config, table_name)
+        self._dialect: str | None = None
+
+    def _get_dialect(self) -> str:
+        """Get the database dialect, caching it after first access.
+
+        Returns:
+            Dialect name (postgres, sqlite, duckdb, mysql, bigquery, snowflake).
+        """
+        if self._dialect is not None:
+            return self._dialect
+
+        with self._config.provide_session() as driver:
+            dialect_value = getattr(driver, "dialect", None)
+            self._dialect = str(dialect_value) if dialect_value else "postgres"
+
+        assert self._dialect is not None
+        return self._dialect
+
+    def _get_create_table_sql(self) -> str:
+        """Get dialect-specific CREATE TABLE SQL for ADBC.
+
+        Returns:
+            SQL statement to create the sessions table with proper indexes.
+
+        Notes:
+            Automatically adapts to the detected database dialect:
+            - PostgreSQL: BYTEA, TIMESTAMPTZ with partial index
+            - SQLite: BLOB, DATETIME
+            - DuckDB: BLOB, TIMESTAMP
+            - MySQL/MariaDB: BLOB, DATETIME
+            - BigQuery: BYTES, TIMESTAMP, clustered by session_id (BigQuery has no CREATE INDEX)
+            - Snowflake: BINARY, TIMESTAMP WITH TIME ZONE
+        """
+        dialect = self._get_dialect()
+
+        if dialect in {"postgres", "postgresql"}:
+            return f"""
+            CREATE TABLE IF NOT EXISTS {self._table_name} (
+                session_id TEXT PRIMARY KEY,
+                data BYTEA NOT NULL,
+                expires_at TIMESTAMPTZ
+            );
+            CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at
+            ON {self._table_name}(expires_at) WHERE expires_at IS NOT NULL;
+            """
+
+        if dialect == "sqlite":
+            return f"""
+            CREATE TABLE IF NOT EXISTS {self._table_name} (
+                session_id TEXT PRIMARY KEY,
+                data BLOB NOT NULL,
+                expires_at DATETIME
+            );
+            CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at
+            ON {self._table_name}(expires_at);
+            """
+
+        if dialect == "duckdb":
+            return f"""
+            CREATE TABLE IF NOT EXISTS {self._table_name} (
+                session_id VARCHAR PRIMARY KEY,
+                data BLOB NOT NULL,
+                expires_at TIMESTAMP
+            );
+            CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at
+            ON {self._table_name}(expires_at);
+            """
+
+        if dialect in {"mysql", "mariadb"}:
+            return f"""
+            CREATE TABLE IF NOT EXISTS {self._table_name} (
+                session_id VARCHAR(255) PRIMARY KEY,
+                data BLOB NOT NULL,
+                expires_at DATETIME
+            );
+            CREATE INDEX idx_{self._table_name}_expires_at
+            ON {self._table_name}(expires_at);
+            """
+
+        if dialect == "bigquery":
+            return f"""
+            CREATE TABLE IF NOT EXISTS {self._table_name} (
+                session_id STRING NOT NULL,
+                data BYTES NOT NULL,
+                expires_at TIMESTAMP
+            )
+            CLUSTER BY session_id;
+            """
+
+        if dialect == "snowflake":
+            return f"""
+            CREATE TABLE IF NOT EXISTS {self._table_name} (
+                session_id VARCHAR(255) PRIMARY KEY,
+                data BINARY NOT NULL,
+                expires_at TIMESTAMP WITH TIME ZONE
+            );
+            CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at
+            ON {self._table_name}(expires_at);
+            """
+
+        return f"""
+        CREATE TABLE IF NOT EXISTS {self._table_name} (
+            session_id TEXT PRIMARY KEY,
+            data BYTEA NOT NULL,
+            expires_at TIMESTAMPTZ
+        );
+        CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at
+        ON 
{self._table_name}(expires_at); + """ + + def _get_param_placeholder(self, position: int) -> str: + """Get the parameter placeholder syntax for the current dialect. + + Args: + position: 1-based parameter position. + + Returns: + Parameter placeholder string (e.g., '$1', '?', '%s', '@param1'). + """ + dialect = self._get_dialect() + + if dialect in {"postgres", "postgresql"}: + return f"${position}" + if dialect in {"mysql", "mariadb"}: + return "%s" + if dialect == "bigquery": + return f"@param{position}" + return "?" + + def _get_current_timestamp_expr(self) -> str: + """Get the current timestamp expression for the current dialect. + + Returns: + SQL expression for getting current timestamp with timezone. + """ + dialect = self._get_dialect() + + if dialect in {"postgres", "postgresql"}: + return "CURRENT_TIMESTAMP AT TIME ZONE 'UTC'" + if dialect in {"mysql", "mariadb"}: + return "UTC_TIMESTAMP()" + if dialect == "bigquery": + return "CURRENT_TIMESTAMP()" + if dialect == "snowflake": + return "CONVERT_TIMEZONE('UTC', CURRENT_TIMESTAMP())" + return "CURRENT_TIMESTAMP" + + def _create_table(self) -> None: + """Synchronous implementation of create_table using ADBC driver.""" + sql_text = self._get_create_table_sql() + with self._config.provide_session() as driver: + for statement in sql_text.strip().split(";"): + statement = statement.strip() + if statement: + driver.execute(statement) + driver.commit() + logger.debug("Created session table: %s", self._table_name) + + def _get_drop_table_sql(self) -> "list[str]": + """Get dialect-specific DROP TABLE SQL statements for ADBC. + + Returns: + List of SQL statements to drop indexes and table. + """ + dialect = self._get_dialect() + + if dialect in {"mysql", "mariadb"}: + return [ + f"DROP INDEX idx_{self._table_name}_expires_at ON {self._table_name}", + f"DROP TABLE IF EXISTS {self._table_name}", + ] + + return [f"DROP INDEX IF EXISTS idx_{self._table_name}_expires_at", f"DROP TABLE IF EXISTS {self._table_name}"] + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + await async_(self._create_table)() + + def _get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Synchronous implementation of get using ADBC driver.""" + p1 = self._get_param_placeholder(1) + current_ts = self._get_current_timestamp_expr() + + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = {p1} + AND (expires_at IS NULL OR expires_at > {current_ts}) + """ + + with self._config.provide_session() as driver: + result = driver.select_one_or_none(sql, key) + + if result is None: + return None + + data = result["data"] + expires_at = result["expires_at"] + + if renew_for is not None and expires_at is not None: + new_expires_at = self._calculate_expires_at(renew_for) + p1_update = self._get_param_placeholder(1) + p2_update = self._get_param_placeholder(2) + update_sql = f""" + UPDATE {self._table_name} + SET expires_at = {p1_update} + WHERE session_id = {p2_update} + """ + driver.execute(update_sql, new_expires_at, key) + driver.commit() + + return bytes(data) + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given, renew the expiry time for this duration. + + Returns: + Session data as bytes if found and not expired, None otherwise. 
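+
+        Example:
+            Illustrative usage only; assumes an already-configured store instance:
+
+                data = await store.get("session-abc", renew_for=300)
+                if data is None:
+                    ...  # session missing or expired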
+ """ + return await async_(self._get)(key, renew_for) + + def _set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Synchronous implementation of set using ADBC driver with dialect-specific UPSERT.""" + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + dialect = self._get_dialect() + + p1 = self._get_param_placeholder(1) + p2 = self._get_param_placeholder(2) + p3 = self._get_param_placeholder(3) + + if dialect in {"postgres", "postgresql", "sqlite", "duckdb"}: + if dialect == "sqlite": + sql = f""" + INSERT OR REPLACE INTO {self._table_name} (session_id, data, expires_at) + VALUES ({p1}, {p2}, {p3}) + """ + else: + sql = f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES ({p1}, {p2}, {p3}) + ON CONFLICT (session_id) DO UPDATE + SET data = EXCLUDED.data, expires_at = EXCLUDED.expires_at + """ + elif dialect in {"mysql", "mariadb"}: + sql = f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES ({p1}, {p2}, {p3}) + ON DUPLICATE KEY UPDATE data = VALUES(data), expires_at = VALUES(expires_at) + """ + elif dialect in {"bigquery", "snowflake"}: + with self._config.provide_session() as driver: + check_sql = f"SELECT COUNT(*) as count FROM {self._table_name} WHERE session_id = {p1}" + result = driver.select_one(check_sql, key) + exists = result and result.get("count", 0) > 0 + + if exists: + sql = f""" + UPDATE {self._table_name} + SET data = {p1}, expires_at = {p2} + WHERE session_id = {p3} + """ + driver.execute(sql, data, expires_at, key) + else: + sql = f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES ({p1}, {p2}, {p3}) + """ + driver.execute(sql, key, data, expires_at) + driver.commit() + return + else: + sql = f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES ({p1}, {p2}, {p3}) + ON CONFLICT (session_id) DO UPDATE + SET data = EXCLUDED.data, expires_at = EXCLUDED.expires_at + """ + + with self._config.provide_session() as driver: + driver.execute(sql, key, data, expires_at) + driver.commit() + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data. + expires_in: Time until expiration. + """ + await async_(self._set)(key, value, expires_in) + + def _delete(self, key: str) -> None: + """Synchronous implementation of delete using ADBC driver.""" + p1 = self._get_param_placeholder(1) + sql = f"DELETE FROM {self._table_name} WHERE session_id = {p1}" + + with self._config.provide_session() as driver: + driver.execute(sql, key) + driver.commit() + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. 
+ """ + await async_(self._delete)(key) + + def _delete_all(self) -> None: + """Synchronous implementation of delete_all using ADBC driver.""" + + sql = f"DELETE FROM {self._table_name}" + + with self._config.provide_session() as driver: + driver.execute(sql) + driver.commit() + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + await async_(self._delete_all)() + + def _exists(self, key: str) -> bool: + """Synchronous implementation of exists using ADBC driver.""" + + p1 = self._get_param_placeholder(1) + current_ts = self._get_current_timestamp_expr() + + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = {p1} + AND (expires_at IS NULL OR expires_at > {current_ts}) + """ + + with self._config.provide_session() as driver: + return bool(driver.select_one_or_none(sql, key) is not None) + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + """ + return await async_(self._exists)(key) + + def _expires_in(self, key: str) -> "int | None": + """Synchronous implementation of expires_in using ADBC driver.""" + p1 = self._get_param_placeholder(1) + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = {p1} + """ + + with self._config.provide_session() as driver: + result = driver.select_one(sql, key) + + if result is None or result.get("expires_at") is None: + return None + + expires_at = result["expires_at"] + + if not isinstance(expires_at, datetime): + return None + + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + + now = datetime.now(timezone.utc) + + if expires_at <= now: + return 0 + + delta = expires_at - now + return int(delta.total_seconds()) + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. + """ + return await async_(self._expires_in)(key) + + def _delete_expired(self) -> int: + """Synchronous implementation of delete_expired using ADBC driver.""" + current_ts = self._get_current_timestamp_expr() + dialect = self._get_dialect() + + if dialect in {"postgres", "postgresql"}: + sql = f"DELETE FROM {self._table_name} WHERE expires_at <= {current_ts} RETURNING session_id" + else: + count_sql = f"SELECT COUNT(*) as count FROM {self._table_name} WHERE expires_at <= {current_ts}" + delete_sql = f"DELETE FROM {self._table_name} WHERE expires_at <= {current_ts}" + + with self._config.provide_session() as driver: + result = driver.select_one(count_sql) + count = result.get("count", 0) if result else 0 + + if count > 0: + driver.execute(delete_sql) + driver.commit() + logger.debug("Cleaned up %d expired sessions", count) + + return count + + with self._config.provide_session() as driver: + exec_result = driver.execute(sql) + driver.commit() + count = exec_result.rows_affected + if count > 0: + logger.debug("Cleaned up %d expired sessions", count) + return count + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. 
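+
+        Example:
+            A hypothetical periodic cleanup call (scheduling not shown):
+
+                removed = await store.delete_expired()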
+        """
+        return await async_(self._delete_expired)()
diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py
index 9fc7c613..8f007042 100644
--- a/sqlspec/adapters/aiosqlite/config.py
+++ b/sqlspec/adapters/aiosqlite/config.py
@@ -63,6 +63,7 @@ def __init__(
        statement_config: "StatementConfig | None" = None,
        driver_features: "dict[str, Any] | None" = None,
        bind_key: "str | None" = None,
+        extension_config: "dict[str, dict[str, Any]] | None" = None,
    ) -> None:
        """Initialize AioSQLite configuration.
@@ -73,6 +74,7 @@ def __init__(
            statement_config: Optional statement configuration.
            driver_features: Optional driver feature configuration.
            bind_key: Optional unique identifier for this configuration.
+            extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
        """
        config_dict = dict(pool_config) if pool_config else {}
@@ -96,6 +98,7 @@ def __init__(
            statement_config=statement_config or aiosqlite_statement_config,
            driver_features=driver_features or {},
            bind_key=bind_key,
+            extension_config=extension_config,
        )

    def _get_pool_config_dict(self) -> "dict[str, Any]":
diff --git a/sqlspec/adapters/aiosqlite/litestar/__init__.py b/sqlspec/adapters/aiosqlite/litestar/__init__.py
new file mode 100644
index 00000000..f27b1a64
--- /dev/null
+++ b/sqlspec/adapters/aiosqlite/litestar/__init__.py
@@ -0,0 +1,5 @@
+"""Litestar integration for AioSQLite adapter."""
+
+from sqlspec.adapters.aiosqlite.litestar.store import AiosqliteStore
+
+__all__ = ("AiosqliteStore",)
diff --git a/sqlspec/adapters/aiosqlite/litestar/store.py b/sqlspec/adapters/aiosqlite/litestar/store.py
new file mode 100644
index 00000000..0a8e0392
--- /dev/null
+++ b/sqlspec/adapters/aiosqlite/litestar/store.py
@@ -0,0 +1,280 @@
+"""AioSQLite session store for Litestar integration."""
+
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING
+
+from sqlspec.extensions.litestar.store import BaseSQLSpecStore
+from sqlspec.utils.logging import get_logger
+
+if TYPE_CHECKING:
+    from sqlspec.adapters.aiosqlite.config import AiosqliteConfig
+
+logger = get_logger("adapters.aiosqlite.litestar.store")
+
+SECONDS_PER_DAY = 86400.0
+JULIAN_EPOCH = 2440587.5
+
+__all__ = ("AiosqliteStore",)
+
+
+class AiosqliteStore(BaseSQLSpecStore["AiosqliteConfig"]):
+    """SQLite session store using AioSQLite driver.
+
+    Implements server-side session storage for Litestar using SQLite
+    via the AioSQLite driver. Provides efficient session management with:
+    - Native async SQLite operations
+    - INSERT OR REPLACE for UPSERT functionality
+    - Automatic expiration handling
+    - Efficient cleanup of expired sessions
+
+    Args:
+        config: AiosqliteConfig instance.
+        table_name: Name of the session table. Defaults to "litestar_session".
+
+    Example:
+        from sqlspec.adapters.aiosqlite import AiosqliteConfig
+        from sqlspec.adapters.aiosqlite.litestar.store import AiosqliteStore
+
+        config = AiosqliteConfig(pool_config={"database": ":memory:"})
+        store = AiosqliteStore(config)
+        await store.create_table()
+    """
+
+    __slots__ = ()
+
+    def __init__(self, config: "AiosqliteConfig", table_name: str = "litestar_session") -> None:
+        """Initialize AioSQLite session store.
+
+        Args:
+            config: AiosqliteConfig instance.
+            table_name: Name of the session table.
+        """
+        super().__init__(config, table_name)
+
+    def _get_create_table_sql(self) -> str:
+        """Get SQLite CREATE TABLE SQL.
+
+        Returns:
+            SQL statement to create the sessions table with proper indexes.
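+
+        Example:
+            Sketch of the Julian Day mapping described in the notes below,
+            using the module constants (values are exact for these inputs):
+
+                # JULIAN_EPOCH (2440587.5) is 1970-01-01T00:00:00Z, so
+                # datetime(1970, 1, 2, tzinfo=timezone.utc) stores as 2440588.5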
+ + Notes: + - Uses REAL type for expires_at (stores Julian Day number) + - Julian Day enables direct comparison with julianday('now') + - Partial index WHERE expires_at IS NOT NULL reduces index size + - This approach ensures the index is actually used by query optimizer + - Table name is internally controlled, not user input (S608 suppressed) + """ + return f""" + CREATE TABLE IF NOT EXISTS {self._table_name} ( + session_id TEXT PRIMARY KEY, + data BLOB NOT NULL, + expires_at REAL + ); + CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at + ON {self._table_name}(expires_at) WHERE expires_at IS NOT NULL; + """ + + def _get_drop_table_sql(self) -> "list[str]": + """Get SQLite DROP TABLE SQL statements. + + Returns: + List of SQL statements to drop indexes and table. + """ + return [f"DROP INDEX IF EXISTS idx_{self._table_name}_expires_at", f"DROP TABLE IF EXISTS {self._table_name}"] + + def _datetime_to_julian(self, dt: "datetime | None") -> "float | None": + """Convert datetime to Julian Day number for SQLite storage. + + Args: + dt: Datetime to convert (must be UTC-aware). + + Returns: + Julian Day number as REAL, or None if dt is None. + + Notes: + Julian Day number is days since November 24, 4714 BCE (proleptic Gregorian). + This enables direct comparison with julianday('now') in SQL queries. + """ + if dt is None: + return None + + epoch = datetime(1970, 1, 1, tzinfo=timezone.utc) + delta_days = (dt - epoch).total_seconds() / SECONDS_PER_DAY + return JULIAN_EPOCH + delta_days + + def _julian_to_datetime(self, julian: "float | None") -> "datetime | None": + """Convert Julian Day number back to datetime. + + Args: + julian: Julian Day number. + + Returns: + UTC-aware datetime, or None if julian is None. + """ + if julian is None: + return None + + days_since_epoch = julian - JULIAN_EPOCH + timestamp = days_since_epoch * SECONDS_PER_DAY + return datetime.fromtimestamp(timestamp, tz=timezone.utc) + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + sql = self._get_create_table_sql() + async with self._config.provide_connection() as conn: + await conn.executescript(sql) + logger.debug("Created session table: %s", self._table_name) + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given, renew the expiry time for this duration. + + Returns: + Session data as bytes if found and not expired, None otherwise. + """ + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = ? + AND (expires_at IS NULL OR julianday(expires_at) > julianday('now')) + """ + + async with self._config.provide_connection() as conn: + async with conn.execute(sql, (key,)) as cursor: + row = await cursor.fetchone() + + if row is None: + return None + + data, expires_at_julian = row + + if renew_for is not None and expires_at_julian is not None: + new_expires_at = self._calculate_expires_at(renew_for) + new_expires_at_julian = self._datetime_to_julian(new_expires_at) + if new_expires_at_julian is not None: + update_sql = f""" + UPDATE {self._table_name} + SET expires_at = ? + WHERE session_id = ? + """ + await conn.execute(update_sql, (new_expires_at_julian, key)) + await conn.commit() + + return bytes(data) + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data. 
+ expires_in: Time until expiration. + + Notes: + Stores expires_at as Julian Day number (REAL) for optimal index usage. + """ + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + expires_at_julian = self._datetime_to_julian(expires_at) + + sql = f""" + INSERT OR REPLACE INTO {self._table_name} (session_id, data, expires_at) + VALUES (?, ?, ?) + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, (key, data, expires_at_julian)) + await conn.commit() + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. + """ + sql = f"DELETE FROM {self._table_name} WHERE session_id = ?" + + async with self._config.provide_connection() as conn: + await conn.execute(sql, (key,)) + await conn.commit() + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + sql = f"DELETE FROM {self._table_name}" + + async with self._config.provide_connection() as conn: + await conn.execute(sql) + await conn.commit() + logger.debug("Deleted all sessions from table: %s", self._table_name) + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + """ + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = ? + AND (expires_at IS NULL OR julianday(expires_at) > julianday('now')) + """ + + async with self._config.provide_connection() as conn, conn.execute(sql, (key,)) as cursor: + result = await cursor.fetchone() + return result is not None + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. + """ + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = ? + """ + + async with self._config.provide_connection() as conn: + async with conn.execute(sql, (key,)) as cursor: + row = await cursor.fetchone() + + if row is None or row[0] is None: + return None + + expires_at_julian = row[0] + expires_at = self._julian_to_datetime(expires_at_julian) + + if expires_at is None: + return None + + now = datetime.now(timezone.utc) + + if expires_at <= now: + return 0 + + delta = expires_at - now + return int(delta.total_seconds()) + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. + """ + sql = f"DELETE FROM {self._table_name} WHERE julianday(expires_at) <= julianday('now')" + + async with self._config.provide_connection() as conn: + cursor = await conn.execute(sql) + await conn.commit() + count = cursor.rowcount + if count > 0: + logger.debug("Cleaned up %d expired sessions", count) + return count diff --git a/sqlspec/adapters/asyncmy/config.py b/sqlspec/adapters/asyncmy/config.py index 2cbc6aef..bbeef967 100644 --- a/sqlspec/adapters/asyncmy/config.py +++ b/sqlspec/adapters/asyncmy/config.py @@ -72,6 +72,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, + extension_config: "dict[str, dict[str, Any]] | None" = None, ) -> None: """Initialize Asyncmy configuration. 
@@ -82,6 +83,7 @@ def __init__(
            statement_config: Statement configuration override
            driver_features: Driver feature configuration
            bind_key: Optional unique identifier for this configuration
+            extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
        """
        processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {}
        if "extra" in processed_pool_config:
@@ -103,6 +105,7 @@ def __init__(
            statement_config=statement_config,
            driver_features=driver_features or {},
            bind_key=bind_key,
+            extension_config=extension_config,
        )

    async def _create_pool(self) -> "AsyncmyPool":  # pyright: ignore
diff --git a/sqlspec/adapters/asyncmy/litestar/__init__.py b/sqlspec/adapters/asyncmy/litestar/__init__.py
new file mode 100644
index 00000000..4f3cd516
--- /dev/null
+++ b/sqlspec/adapters/asyncmy/litestar/__init__.py
@@ -0,0 +1,5 @@
+"""Litestar integration for AsyncMy adapter."""
+
+from sqlspec.adapters.asyncmy.litestar.store import AsyncmyStore
+
+__all__ = ("AsyncmyStore",)
diff --git a/sqlspec/adapters/asyncmy/litestar/store.py b/sqlspec/adapters/asyncmy/litestar/store.py
new file mode 100644
index 00000000..d066e6ec
--- /dev/null
+++ b/sqlspec/adapters/asyncmy/litestar/store.py
@@ -0,0 +1,295 @@
+"""AsyncMy session store for Litestar integration."""
+
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, Final
+
+from sqlspec.extensions.litestar.store import BaseSQLSpecStore
+from sqlspec.utils.logging import get_logger
+
+if TYPE_CHECKING:
+    from sqlspec.adapters.asyncmy.config import AsyncmyConfig
+
+logger = get_logger("adapters.asyncmy.litestar.store")
+
+__all__ = ("AsyncmyStore",)
+
+MYSQL_TABLE_NOT_FOUND_ERROR: Final = 1146
+
+
+class AsyncmyStore(BaseSQLSpecStore["AsyncmyConfig"]):
+    """MySQL/MariaDB session store using AsyncMy driver.
+
+    Implements server-side session storage for Litestar using MySQL/MariaDB
+    via the AsyncMy driver. Provides efficient session management with:
+    - Native async MySQL operations
+    - UPSERT support using ON DUPLICATE KEY UPDATE
+    - Automatic expiration handling
+    - Efficient cleanup of expired sessions
+    - Timezone-aware expiration (stored as UTC in DATETIME)
+
+    Args:
+        config: AsyncmyConfig instance.
+        table_name: Name of the session table. Defaults to "litestar_session".
+
+    Example:
+        from sqlspec.adapters.asyncmy import AsyncmyConfig
+        from sqlspec.adapters.asyncmy.litestar.store import AsyncmyStore
+
+        config = AsyncmyConfig(pool_config={"host": "localhost", ...})
+        store = AsyncmyStore(config)
+        await store.create_table()
+
+    Notes:
+        MySQL DATETIME is timezone-naive, so UTC datetimes are stored without
+        timezone info, and timezone conversion is handled in the Python layer.
+    """
+
+    __slots__ = ()
+
+    def __init__(self, config: "AsyncmyConfig", table_name: str = "litestar_session") -> None:
+        """Initialize AsyncMy session store.
+
+        Args:
+            config: AsyncmyConfig instance.
+            table_name: Name of the session table.
+        """
+        super().__init__(config, table_name)
+
+    def _get_create_table_sql(self) -> str:
+        """Get MySQL CREATE TABLE SQL with optimized schema.
+
+        Returns:
+            SQL statement to create the sessions table with proper indexes.
+ + Notes: + - Uses DATETIME(6) for microsecond precision timestamps + - MySQL doesn't have TIMESTAMPTZ, so we store UTC as timezone-naive + - LONGBLOB for large session data support (up to 4GB) + - InnoDB engine for ACID compliance and proper transaction support + - UTF8MB4 for full Unicode support (including emoji) + - Index on expires_at for efficient cleanup queries + - Auto-update of updated_at timestamp on row modification + - Table name is internally controlled, not user input (S608 suppressed) + """ + return f""" + CREATE TABLE IF NOT EXISTS {self._table_name} ( + session_id VARCHAR(255) PRIMARY KEY, + data LONGBLOB NOT NULL, + expires_at DATETIME(6), + created_at DATETIME(6) DEFAULT CURRENT_TIMESTAMP(6), + updated_at DATETIME(6) DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + INDEX idx_{self._table_name}_expires_at (expires_at) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + def _get_drop_table_sql(self) -> "list[str]": + """Get MySQL/MariaDB DROP TABLE SQL statements. + + Returns: + List of SQL statements to drop indexes and table. + """ + return [ + f"DROP INDEX idx_{self._table_name}_expires_at ON {self._table_name}", + f"DROP TABLE IF EXISTS {self._table_name}", + ] + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + sql = self._get_create_table_sql() + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql) + logger.debug("Created session table: %s", self._table_name) + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given, renew the expiry time for this duration. + + Returns: + Session data as bytes if found and not expired, None otherwise. + + Notes: + Uses UTC_TIMESTAMP(6) for microsecond precision current time in MySQL. + Compares expires_at as UTC datetime (timezone-naive in MySQL). + """ + import asyncmy + + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = %s + AND (expires_at IS NULL OR expires_at > UTC_TIMESTAMP(6)) + """ + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (key,)) + row = await cursor.fetchone() + + if row is None: + return None + + data_value, expires_at = row + + if renew_for is not None and expires_at is not None: + new_expires_at = self._calculate_expires_at(renew_for) + if new_expires_at is not None: + naive_expires_at = new_expires_at.replace(tzinfo=None) + update_sql = f""" + UPDATE {self._table_name} + SET expires_at = %s, updated_at = UTC_TIMESTAMP(6) + WHERE session_id = %s + """ + await cursor.execute(update_sql, (naive_expires_at, key)) + await conn.commit() + + return bytes(data_value) + except asyncmy.errors.ProgrammingError as e: # pyright: ignore + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return None + raise + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data. + expires_in: Time until expiration. + + Notes: + Uses INSERT ... ON DUPLICATE KEY UPDATE for efficient UPSERT. + Stores UTC datetime as timezone-naive DATETIME in MySQL. + Uses alias syntax (AS new) instead of deprecated VALUES() function. 
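+            Caveat: the AS-alias form requires MySQL 8.0.19+, and MariaDB may
+            still need the legacy VALUES() syntax (unverified here, so treat
+            this as an assumption when targeting MariaDB).
+
+        Example:
+            Illustrative usage only; assumes an already-configured store:
+
+                await store.set("session-abc", b"payload", expires_in=3600)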
+ """ + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + naive_expires_at = expires_at.replace(tzinfo=None) if expires_at else None + + sql = f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES (%s, %s, %s) AS new + ON DUPLICATE KEY UPDATE + data = new.data, + expires_at = new.expires_at, + updated_at = UTC_TIMESTAMP(6) + """ + + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (key, data, naive_expires_at)) + await conn.commit() + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. + """ + sql = f"DELETE FROM {self._table_name} WHERE session_id = %s" + + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (key,)) + await conn.commit() + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + import asyncmy + + sql = f"DELETE FROM {self._table_name}" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql) + await conn.commit() + logger.debug("Deleted all sessions from table: %s", self._table_name) + except asyncmy.errors.ProgrammingError as e: # pyright: ignore + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + logger.debug("Table %s does not exist, skipping delete_all", self._table_name) + return + raise + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + + Notes: + Uses UTC_TIMESTAMP(6) for microsecond precision current time comparison. + """ + import asyncmy + + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = %s + AND (expires_at IS NULL OR expires_at > UTC_TIMESTAMP(6)) + """ + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (key,)) + result = await cursor.fetchone() + return result is not None + except asyncmy.errors.ProgrammingError as e: # pyright: ignore + if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return False + raise + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. + + Notes: + MySQL DATETIME is timezone-naive, but we treat it as UTC. + Compare against UTC now in Python layer for accuracy. + """ + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = %s + """ + + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (key,)) + row = await cursor.fetchone() + + if row is None or row[0] is None: + return None + + expires_at_naive = row[0] + expires_at_utc = expires_at_naive.replace(tzinfo=timezone.utc) + now = datetime.now(timezone.utc) + + if expires_at_utc <= now: + return 0 + + delta = expires_at_utc - now + return int(delta.total_seconds()) + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. + + Notes: + Uses UTC_TIMESTAMP(6) for microsecond precision current time comparison. + ROW_COUNT() returns the number of affected rows. 
+ """ + sql = f"DELETE FROM {self._table_name} WHERE expires_at <= UTC_TIMESTAMP(6)" + + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql) + await conn.commit() + count: int = cursor.rowcount + if count > 0: + logger.debug("Cleaned up %d expired sessions", count) + return count diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index 71e786da..3dea885b 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -85,6 +85,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "AsyncpgDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, + extension_config: "dict[str, dict[str, Any]] | None" = None, ) -> None: """Initialize AsyncPG configuration. @@ -95,6 +96,7 @@ def __init__( statement_config: Statement configuration override driver_features: Driver features configuration (TypedDict or dict) bind_key: Optional unique identifier for this configuration + extension_config: Extension-specific configuration (e.g., Litestar plugin settings) """ features_dict: dict[str, Any] = dict(driver_features) if driver_features else {} @@ -109,6 +111,7 @@ def __init__( statement_config=statement_config or asyncpg_statement_config, driver_features=features_dict, bind_key=bind_key, + extension_config=extension_config, ) def _get_pool_config_dict(self) -> "dict[str, Any]": diff --git a/sqlspec/adapters/asyncpg/litestar/__init__.py b/sqlspec/adapters/asyncpg/litestar/__init__.py new file mode 100644 index 00000000..7e34096b --- /dev/null +++ b/sqlspec/adapters/asyncpg/litestar/__init__.py @@ -0,0 +1,5 @@ +"""Litestar integration for AsyncPG adapter.""" + +from sqlspec.adapters.asyncpg.litestar.store import AsyncpgStore + +__all__ = ("AsyncpgStore",) diff --git a/sqlspec/adapters/asyncpg/litestar/store.py b/sqlspec/adapters/asyncpg/litestar/store.py new file mode 100644 index 00000000..8ab9aba6 --- /dev/null +++ b/sqlspec/adapters/asyncpg/litestar/store.py @@ -0,0 +1,249 @@ +"""AsyncPG session store for Litestar integration.""" + +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING + +from sqlspec.extensions.litestar.store import BaseSQLSpecStore +from sqlspec.utils.logging import get_logger + +if TYPE_CHECKING: + from sqlspec.adapters.asyncpg.config import AsyncpgConfig + +logger = get_logger("adapters.asyncpg.litestar.store") + +__all__ = ("AsyncpgStore",) + + +class AsyncpgStore(BaseSQLSpecStore["AsyncpgConfig"]): + """PostgreSQL session store using AsyncPG driver. + + Implements server-side session storage for Litestar using PostgreSQL + via the AsyncPG driver. Provides efficient session management with: + - Native async PostgreSQL operations + - UPSERT support using ON CONFLICT + - Automatic expiration handling + - Efficient cleanup of expired sessions + + Args: + config: AsyncpgConfig instance. + table_name: Name of the session table. Defaults to "litestar_session". + + Example: + from sqlspec.adapters.asyncpg import AsyncpgConfig + from sqlspec.adapters.asyncpg.litestar.store import AsyncpgStore + + config = AsyncpgConfig(pool_config={"dsn": "postgresql://..."}) + store = AsyncpgStore(config) + await store.create_table() + """ + + __slots__ = () + + def __init__(self, config: "AsyncpgConfig", table_name: str = "litestar_session") -> None: + """Initialize AsyncPG session store. + + Args: + config: AsyncpgConfig instance. + table_name: Name of the session table. 
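+
+        Example:
+            Hypothetical wiring into a Litestar app (the "sessions" store name
+            is an assumption, not part of this adapter):
+
+                store = AsyncpgStore(config)
+                app = Litestar(route_handlers=[...], stores={"sessions": store})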
+ """ + super().__init__(config, table_name) + + def _get_create_table_sql(self) -> str: + """Get PostgreSQL CREATE TABLE SQL with optimized schema. + + Returns: + SQL statement to create the sessions table with proper indexes. + + Notes: + - Uses TIMESTAMPTZ for timezone-aware expiration timestamps + - Partial index WHERE expires_at IS NOT NULL reduces index size/maintenance + - FILLFACTOR 80 leaves space for HOT updates, reducing table bloat + - Audit columns (created_at, updated_at) help with debugging + - Table name is internally controlled, not user input (S608 suppressed) + """ + return f""" + CREATE TABLE IF NOT EXISTS {self._table_name} ( + session_id TEXT PRIMARY KEY, + data BYTEA NOT NULL, + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at + ON {self._table_name}(expires_at) WHERE expires_at IS NOT NULL; + + ALTER TABLE {self._table_name} SET ( + autovacuum_vacuum_scale_factor = 0.05, + autovacuum_analyze_scale_factor = 0.02 + ); + """ + + def _get_drop_table_sql(self) -> "list[str]": + """Get PostgreSQL DROP TABLE SQL statements. + + Returns: + List of SQL statements to drop indexes and table. + """ + return [f"DROP INDEX IF EXISTS idx_{self._table_name}_expires_at", f"DROP TABLE IF EXISTS {self._table_name}"] + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + sql = self._get_create_table_sql() + async with self._config.provide_connection() as conn: + await conn.execute(sql) + logger.debug("Created session table: %s", self._table_name) + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given, renew the expiry time for this duration. + + Returns: + Session data as bytes if found and not expired, None otherwise. + + Notes: + Uses CURRENT_TIMESTAMP instead of NOW() for SQL standard compliance. + The query planner can use the partial index for expires_at > CURRENT_TIMESTAMP. + """ + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = $1 + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, key) + + if row is None: + return None + + if renew_for is not None and row["expires_at"] is not None: + new_expires_at = self._calculate_expires_at(renew_for) + if new_expires_at is not None: + update_sql = f""" + UPDATE {self._table_name} + SET expires_at = $1, updated_at = CURRENT_TIMESTAMP + WHERE session_id = $2 + """ + await conn.execute(update_sql, new_expires_at, key) + + return bytes(row["data"]) + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data. + expires_in: Time until expiration. + + Notes: + Uses EXCLUDED to reference the proposed insert values in ON CONFLICT. + Updates updated_at timestamp on every write for audit trail. 
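+
+        Example:
+            Illustrative usage only; assumes an already-configured store:
+
+                await store.set("session-abc", b"payload", expires_in=timedelta(hours=1))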
+ """ + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + + sql = f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES ($1, $2, $3) + ON CONFLICT (session_id) + DO UPDATE SET + data = EXCLUDED.data, + expires_at = EXCLUDED.expires_at, + updated_at = CURRENT_TIMESTAMP + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, key, data, expires_at) + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. + """ + sql = f"DELETE FROM {self._table_name} WHERE session_id = $1" + + async with self._config.provide_connection() as conn: + await conn.execute(sql, key) + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + sql = f"DELETE FROM {self._table_name}" + + async with self._config.provide_connection() as conn: + await conn.execute(sql) + logger.debug("Deleted all sessions from table: %s", self._table_name) + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + + Notes: + Uses CURRENT_TIMESTAMP for consistency with get() method. + """ + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = $1 + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + result = await conn.fetchval(sql, key) + return result is not None + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. + """ + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = $1 + """ + + async with self._config.provide_connection() as conn: + expires_at = await conn.fetchval(sql, key) + + if expires_at is None: + return None + + now = datetime.now(timezone.utc) + if expires_at <= now: + return 0 + + delta = expires_at - now + return int(delta.total_seconds()) + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. + + Notes: + Uses CURRENT_TIMESTAMP for consistency. + For very large tables (10M+ rows), consider batching deletes + to avoid holding locks too long. + """ + sql = f"DELETE FROM {self._table_name} WHERE expires_at <= CURRENT_TIMESTAMP" + + async with self._config.provide_connection() as conn: + result = await conn.execute(sql) + count = int(result.split()[-1]) + if count > 0: + logger.debug("Cleaned up %d expired sessions", count) + return count diff --git a/sqlspec/adapters/bigquery/config.py b/sqlspec/adapters/bigquery/config.py index e4c3ffb0..beb5bb96 100644 --- a/sqlspec/adapters/bigquery/config.py +++ b/sqlspec/adapters/bigquery/config.py @@ -95,6 +95,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "BigQueryDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, + extension_config: "dict[str, dict[str, Any]] | None" = None, ) -> None: """Initialize BigQuery configuration. 
@@ -104,6 +105,7 @@ def __init__( statement_config: Statement configuration override driver_features: BigQuery-specific driver features bind_key: Optional unique identifier for this configuration + extension_config: Extension-specific configuration (e.g., Litestar plugin settings) """ self.connection_config: dict[str, Any] = dict(connection_config) if connection_config else {} @@ -127,6 +129,7 @@ def __init__( statement_config=statement_config, driver_features=self.driver_features, bind_key=bind_key, + extension_config=extension_config, ) def _setup_default_job_config(self) -> None: diff --git a/sqlspec/adapters/bigquery/litestar/__init__.py b/sqlspec/adapters/bigquery/litestar/__init__.py new file mode 100644 index 00000000..d19456eb --- /dev/null +++ b/sqlspec/adapters/bigquery/litestar/__init__.py @@ -0,0 +1,5 @@ +"""BigQuery Litestar integration.""" + +from sqlspec.adapters.bigquery.litestar.store import BigQueryStore + +__all__ = ("BigQueryStore",) diff --git a/sqlspec/adapters/bigquery/litestar/store.py b/sqlspec/adapters/bigquery/litestar/store.py new file mode 100644 index 00000000..11c2edad --- /dev/null +++ b/sqlspec/adapters/bigquery/litestar/store.py @@ -0,0 +1,326 @@ +"""BigQuery session store for Litestar integration.""" + +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING + +from sqlspec.extensions.litestar.store import BaseSQLSpecStore +from sqlspec.utils.logging import get_logger +from sqlspec.utils.sync_tools import async_ + +if TYPE_CHECKING: + from sqlspec.adapters.bigquery.config import BigQueryConfig + +logger = get_logger("adapters.bigquery.litestar.store") + +__all__ = ("BigQueryStore",) + + +class BigQueryStore(BaseSQLSpecStore["BigQueryConfig"]): + """BigQuery session store using synchronous BigQuery driver. + + Implements server-side session storage for Litestar using Google BigQuery. + Uses Litestar's sync_to_thread utility to provide an async interface + compatible with the Store protocol. + + Provides efficient session management with: + - Sync operations wrapped for async compatibility + - MERGE for UPSERT functionality + - Native TIMESTAMP type support + - Automatic expiration handling + - Efficient cleanup of expired sessions + - Table clustering on session_id for optimized lookups + + Note: + BigQuery is designed for analytical (OLAP) workloads and scales to petabytes. + For typical session store workloads, clustering by session_id provides good + performance. Consider partitioning by created_at if session volume exceeds + millions of rows per day. + + Args: + config: BigQueryConfig instance. + table_name: Name of the session table. Defaults to "litestar_session". + + Example: + from sqlspec.adapters.bigquery import BigQueryConfig + from sqlspec.adapters.bigquery.litestar.store import BigQueryStore + + config = BigQueryConfig(connection_config={"project": "my-project"}) + store = BigQueryStore(config) + await store.create_table() + """ + + __slots__ = () + + def __init__(self, config: "BigQueryConfig", table_name: str = "litestar_session") -> None: + """Initialize BigQuery session store. + + Args: + config: BigQueryConfig instance. + table_name: Name of the session table. + """ + super().__init__(config, table_name) + + def _get_create_table_sql(self) -> str: + """Get BigQuery CREATE TABLE SQL with optimized schema. + + Returns: + SQL statement to create the sessions table with clustering. 
+ + Notes: + - Uses TIMESTAMP for timezone-aware expiration timestamps + - BYTES for binary session data storage + - Clustered by session_id for efficient lookups + - No indexes needed - BigQuery uses columnar storage + - Table name is internally controlled, not user input + """ + return f""" + CREATE TABLE IF NOT EXISTS {self._table_name} ( + session_id STRING NOT NULL, + data BYTES NOT NULL, + expires_at TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP() + ) + CLUSTER BY session_id + """ + + def _get_drop_table_sql(self) -> "list[str]": + """Get BigQuery DROP TABLE SQL statements. + + Returns: + List containing DROP TABLE statement. + + Notes: + BigQuery doesn't have separate indexes to drop. + """ + return [f"DROP TABLE IF EXISTS {self._table_name}"] + + def _datetime_to_timestamp(self, dt: "datetime | None") -> "datetime | None": + """Convert datetime to BigQuery TIMESTAMP. + + Args: + dt: Datetime to convert (must be UTC-aware). + + Returns: + UTC datetime object, or None if dt is None. + + Notes: + BigQuery TIMESTAMP type expects UTC datetime objects. + The BigQuery client library handles the conversion. + """ + if dt is None: + return None + if dt.tzinfo is None: + return dt.replace(tzinfo=timezone.utc) + return dt + + def _timestamp_to_datetime(self, ts: "datetime | None") -> "datetime | None": + """Convert TIMESTAMP back to datetime. + + Args: + ts: Datetime object from BigQuery. + + Returns: + UTC-aware datetime, or None if ts is None. + """ + if ts is None: + return None + if ts.tzinfo is None: + return ts.replace(tzinfo=timezone.utc) + return ts + + def _create_table(self) -> None: + """Synchronous implementation of create_table.""" + sql = self._get_create_table_sql() + with self._config.provide_session() as driver: + driver.execute(sql) + logger.debug("Created session table: %s", self._table_name) + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + await async_(self._create_table)() + + def _get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Synchronous implementation of get.""" + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = @session_id + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP()) + """ + + with self._config.provide_session() as driver: + result = driver.select_one(sql, {"session_id": key}) + + if result is None: + return None + + data = result.get("data") + expires_at = result.get("expires_at") + + if renew_for is not None and expires_at is not None: + new_expires_at = self._calculate_expires_at(renew_for) + new_expires_at_ts = self._datetime_to_timestamp(new_expires_at) + if new_expires_at_ts is not None: + update_sql = f""" + UPDATE {self._table_name} + SET expires_at = @expires_at + WHERE session_id = @session_id + """ + driver.execute(update_sql, {"expires_at": new_expires_at_ts, "session_id": key}) + + return bytes(data) if data is not None else None + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given, renew the expiry time for this duration. + + Returns: + Session data as bytes if found and not expired, None otherwise. + """ + return await async_(self._get)(key, renew_for) + + def _set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Synchronous implementation of set. + + Notes: + Uses MERGE for UPSERT functionality in BigQuery. 
+ BigQuery requires source data to come from a table or inline VALUES. + """ + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + expires_at_ts = self._datetime_to_timestamp(expires_at) + + sql = f""" + MERGE {self._table_name} AS target + USING (SELECT @session_id AS session_id, @data AS data, @expires_at AS expires_at) AS source + ON target.session_id = source.session_id + WHEN MATCHED THEN + UPDATE SET data = source.data, expires_at = source.expires_at + WHEN NOT MATCHED THEN + INSERT (session_id, data, expires_at, created_at) + VALUES (source.session_id, source.data, source.expires_at, CURRENT_TIMESTAMP()) + """ + + with self._config.provide_session() as driver: + driver.execute(sql, {"session_id": key, "data": data, "expires_at": expires_at_ts}) + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data. + expires_in: Time until expiration. + """ + await async_(self._set)(key, value, expires_in) + + def _delete(self, key: str) -> None: + """Synchronous implementation of delete.""" + sql = f"DELETE FROM {self._table_name} WHERE session_id = @session_id" + + with self._config.provide_session() as driver: + driver.execute(sql, {"session_id": key}) + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. + """ + await async_(self._delete)(key) + + def _delete_all(self) -> None: + """Synchronous implementation of delete_all.""" + sql = f"DELETE FROM {self._table_name} WHERE TRUE" + + with self._config.provide_session() as driver: + driver.execute(sql) + logger.debug("Deleted all sessions from table: %s", self._table_name) + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + await async_(self._delete_all)() + + def _exists(self, key: str) -> bool: + """Synchronous implementation of exists.""" + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = @session_id + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP()) + LIMIT 1 + """ + + with self._config.provide_session() as driver: + result = driver.select_one(sql, {"session_id": key}) + return result is not None + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + """ + return await async_(self._exists)(key) + + def _expires_in(self, key: str) -> "int | None": + """Synchronous implementation of expires_in.""" + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = @session_id + """ + + with self._config.provide_session() as driver: + result = driver.select_one(sql, {"session_id": key}) + + if result is None: + return None + + expires_at = result.get("expires_at") + if expires_at is None: + return None + + expires_at_dt = self._timestamp_to_datetime(expires_at) + if expires_at_dt is None: + return None + + now = datetime.now(timezone.utc) + if expires_at_dt <= now: + return 0 + + delta = expires_at_dt - now + return int(delta.total_seconds()) + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. 
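+
+        Example:
+            Illustrative usage only; assumes an already-configured store:
+
+                ttl = await store.expires_in("session-abc")
+                if ttl == 0:
+                    ...  # already expired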
+ """ + return await async_(self._expires_in)(key) + + def _delete_expired(self) -> int: + """Synchronous implementation of delete_expired.""" + sql = f"DELETE FROM {self._table_name} WHERE expires_at <= CURRENT_TIMESTAMP()" + + with self._config.provide_session() as driver: + result = driver.execute(sql) + count = result.get_affected_count() + if count > 0: + logger.debug("Cleaned up %d expired sessions", count) + return count + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. + """ + return await async_(self._delete_expired)() diff --git a/sqlspec/adapters/duckdb/config.py b/sqlspec/adapters/duckdb/config.py index faffa3ee..e1673b5d 100644 --- a/sqlspec/adapters/duckdb/config.py +++ b/sqlspec/adapters/duckdb/config.py @@ -150,8 +150,19 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "DuckDBDriverFeatures | dict[str, Any] | None" = None, bind_key: "str | None" = None, + extension_config: "dict[str, dict[str, Any]] | None" = None, ) -> None: - """Initialize DuckDB configuration.""" + """Initialize DuckDB configuration. + + Args: + pool_config: Pool configuration parameters + pool_instance: Pre-created pool instance + migration_config: Migration configuration + statement_config: Statement configuration override + driver_features: DuckDB-specific driver features + bind_key: Optional unique identifier for this configuration + extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + """ if pool_config is None: pool_config = {} if "database" not in pool_config: @@ -167,6 +178,7 @@ def __init__( migration_config=migration_config, statement_config=statement_config or duckdb_statement_config, driver_features=cast("dict[str, Any]", driver_features), + extension_config=extension_config, ) def _get_connection_config_dict(self) -> "dict[str, Any]": diff --git a/sqlspec/adapters/duckdb/litestar/__init__.py b/sqlspec/adapters/duckdb/litestar/__init__.py new file mode 100644 index 00000000..c6b17c01 --- /dev/null +++ b/sqlspec/adapters/duckdb/litestar/__init__.py @@ -0,0 +1,5 @@ +"""Litestar integration for DuckDB adapter.""" + +from sqlspec.adapters.duckdb.litestar.store import DuckdbStore + +__all__ = ("DuckdbStore",) diff --git a/sqlspec/adapters/duckdb/litestar/store.py b/sqlspec/adapters/duckdb/litestar/store.py new file mode 100644 index 00000000..b62d07a2 --- /dev/null +++ b/sqlspec/adapters/duckdb/litestar/store.py @@ -0,0 +1,331 @@ +"""DuckDB sync session store for Litestar integration.""" + +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING + +from sqlspec.extensions.litestar.store import BaseSQLSpecStore +from sqlspec.utils.logging import get_logger +from sqlspec.utils.sync_tools import async_ + +if TYPE_CHECKING: + from sqlspec.adapters.duckdb.config import DuckDBConfig + +logger = get_logger("adapters.duckdb.litestar.store") + +__all__ = ("DuckdbStore",) + + +class DuckdbStore(BaseSQLSpecStore["DuckDBConfig"]): + """DuckDB session store using synchronous DuckDB driver. + + Implements server-side session storage for Litestar using DuckDB + via the synchronous duckdb driver. Uses Litestar's sync_to_thread + utility to provide an async interface compatible with the Store protocol. 
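+
+    Conceptually, the wrapping is equivalent to pushing each blocking call
+    onto a worker thread. A minimal stdlib sketch of the same idea
+    (illustrative only; the real wrapper is sqlspec.utils.sync_tools.async_):
+
+        import asyncio
+
+        async def run_blocking(fn, *args):
+            # Run a blocking DuckDB call without stalling the event loop.
+            return await asyncio.to_thread(fn, *args)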
+
+    Provides efficient session management with:
+    - Sync operations wrapped for async compatibility
+    - INSERT ... ON CONFLICT for UPSERT functionality
+    - Native TIMESTAMP type support
+    - Automatic expiration handling
+    - Efficient cleanup of expired sessions
+    - Columnar storage optimized for analytical queries
+
+    Note:
+        DuckDB is primarily designed for analytical (OLAP) workloads.
+        For high-concurrency OLTP session stores, consider PostgreSQL adapters.
+
+    Args:
+        config: DuckDBConfig instance.
+        table_name: Name of the session table. Defaults to "litestar_session".
+
+    Example:
+        from sqlspec.adapters.duckdb import DuckDBConfig
+        from sqlspec.adapters.duckdb.litestar.store import DuckdbStore
+
+        config = DuckDBConfig()
+        store = DuckdbStore(config)
+        await store.create_table()
+    """
+
+    __slots__ = ()
+
+    def __init__(self, config: "DuckDBConfig", table_name: str = "litestar_session") -> None:
+        """Initialize DuckDB session store.
+
+        Args:
+            config: DuckDBConfig instance.
+            table_name: Name of the session table.
+        """
+        super().__init__(config, table_name)
+
+    def _get_create_table_sql(self) -> str:
+        """Get DuckDB CREATE TABLE SQL.
+
+        Returns:
+            SQL statement to create the sessions table with proper indexes.
+
+        Notes:
+            - Uses TIMESTAMP type for expires_at (DuckDB native datetime type)
+            - TIMESTAMP supports ISO 8601 format and direct comparisons
+            - Columnar storage makes this efficient for analytical queries
+            - DuckDB does not support partial indexes, so full index is created
+        """
+        return f"""
+        CREATE TABLE IF NOT EXISTS {self._table_name} (
+            session_id VARCHAR PRIMARY KEY,
+            data BLOB NOT NULL,
+            expires_at TIMESTAMP,
+            created_at TIMESTAMP DEFAULT NOW(),
+            updated_at TIMESTAMP DEFAULT NOW()
+        );
+        CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at
+        ON {self._table_name}(expires_at);
+        """
+
+    def _get_drop_table_sql(self) -> "list[str]":
+        """Get DuckDB DROP TABLE SQL statements.
+
+        Returns:
+            List of SQL statements to drop indexes and table.
+        """
+        return [f"DROP INDEX IF EXISTS idx_{self._table_name}_expires_at", f"DROP TABLE IF EXISTS {self._table_name}"]
+
+    def _datetime_to_timestamp(self, dt: "datetime | None") -> "str | None":
+        """Convert datetime to ISO 8601 string for DuckDB TIMESTAMP storage.
+
+        Args:
+            dt: Datetime to convert (must be UTC-aware).
+
+        Returns:
+            ISO 8601 formatted string, or None if dt is None.
+
+        Notes:
+            DuckDB's TIMESTAMP type accepts ISO 8601 format strings.
+            This enables efficient storage and comparison operations.
+        """
+        if dt is None:
+            return None
+        return dt.isoformat()
+
+    def _timestamp_to_datetime(self, ts: "str | datetime | None") -> "datetime | None":
+        """Convert TIMESTAMP string back to datetime.
+
+        Args:
+            ts: ISO 8601 timestamp string or datetime object.
+
+        Returns:
+            UTC-aware datetime, or None if ts is None.
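+
+        Round-trip sketch (illustrative):
+            from datetime import datetime, timezone
+
+            s = datetime.now(timezone.utc).isoformat()
+            assert datetime.fromisoformat(s).tzinfo is not None  # offset survives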
+ """ + if ts is None: + return None + if isinstance(ts, datetime): + if ts.tzinfo is None: + return ts.replace(tzinfo=timezone.utc) + return ts + dt = datetime.fromisoformat(ts) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + + def _create_table(self) -> None: + """Synchronous implementation of create_table.""" + sql = self._get_create_table_sql() + with self._config.provide_connection() as conn: + conn.execute(sql) + logger.debug("Created session table: %s", self._table_name) + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + await async_(self._create_table)() + + def _get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Synchronous implementation of get.""" + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = ? + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) + """ + + with self._config.provide_connection() as conn: + cursor = conn.execute(sql, (key,)) + row = cursor.fetchone() + + if row is None: + return None + + data, expires_at_str = row + + if renew_for is not None and expires_at_str is not None: + new_expires_at = self._calculate_expires_at(renew_for) + new_expires_at_str = self._datetime_to_timestamp(new_expires_at) + if new_expires_at_str is not None: + update_sql = f""" + UPDATE {self._table_name} + SET expires_at = ?, updated_at = NOW() + WHERE session_id = ? + """ + conn.execute(update_sql, (new_expires_at_str, key)) + conn.commit() + + return bytes(data) + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given, renew the expiry time for this duration. + + Returns: + Session data as bytes if found and not expired, None otherwise. + """ + return await async_(self._get)(key, renew_for) + + def _set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Synchronous implementation of set. + + Notes: + Stores expires_at as TIMESTAMP (ISO 8601 string) for DuckDB native support. + Uses INSERT ON CONFLICT instead of INSERT OR REPLACE to ensure all columns + are properly updated. created_at uses DEFAULT on insert, updated_at gets + current timestamp on both insert and update. + """ + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + expires_at_str = self._datetime_to_timestamp(expires_at) + + sql = f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES (?, ?, ?) + ON CONFLICT (session_id) + DO UPDATE SET + data = EXCLUDED.data, + expires_at = EXCLUDED.expires_at, + updated_at = NOW() + """ + + with self._config.provide_connection() as conn: + conn.execute(sql, (key, data, expires_at_str)) + conn.commit() + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data. + expires_in: Time until expiration. + """ + await async_(self._set)(key, value, expires_in) + + def _delete(self, key: str) -> None: + """Synchronous implementation of delete.""" + sql = f"DELETE FROM {self._table_name} WHERE session_id = ?" + + with self._config.provide_connection() as conn: + conn.execute(sql, (key,)) + conn.commit() + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. 
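+
+        Example (illustrative):
+            await store.delete("sid")   # deleting an absent key is a no-op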
+ """ + await async_(self._delete)(key) + + def _delete_all(self) -> None: + """Synchronous implementation of delete_all.""" + sql = f"DELETE FROM {self._table_name}" + + with self._config.provide_connection() as conn: + conn.execute(sql) + conn.commit() + logger.debug("Deleted all sessions from table: %s", self._table_name) + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + await async_(self._delete_all)() + + def _exists(self, key: str) -> bool: + """Synchronous implementation of exists.""" + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = ? + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) + """ + + with self._config.provide_connection() as conn: + cursor = conn.execute(sql, (key,)) + result = cursor.fetchone() + return result is not None + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + """ + return await async_(self._exists)(key) + + def _expires_in(self, key: str) -> "int | None": + """Synchronous implementation of expires_in.""" + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = ? + """ + + with self._config.provide_connection() as conn: + cursor = conn.execute(sql, (key,)) + row = cursor.fetchone() + + if row is None or row[0] is None: + return None + + expires_at_str = row[0] + expires_at = self._timestamp_to_datetime(expires_at_str) + + if expires_at is None: + return None + + now = datetime.now(timezone.utc) + + if expires_at <= now: + return 0 + + delta = expires_at - now + return int(delta.total_seconds()) + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. + """ + return await async_(self._expires_in)(key) + + def _delete_expired(self) -> int: + """Synchronous implementation of delete_expired.""" + sql = f"DELETE FROM {self._table_name} WHERE expires_at <= CURRENT_TIMESTAMP" + + with self._config.provide_connection() as conn: + cursor = conn.execute(sql) + count = cursor.fetchone() + row_count = count[0] if count else 0 + if row_count > 0: + logger.debug("Cleaned up %d expired sessions", row_count) + return row_count + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. + """ + return await async_(self._delete_expired)() diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index 5a2555c2..479b13a9 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -95,6 +95,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, + extension_config: "dict[str, dict[str, Any]] | None" = None, ) -> None: """Initialize Oracle synchronous configuration. 
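The recurring change across these adapter configs is the new extension_config
parameter. A sketch of the intended shape (the "litestar" key and its settings
below are hypothetical placeholders, not defined by this diff):

    config = OracleSyncConfig(
        pool_config={"dsn": "oracle://..."},
        extension_config={"litestar": {"session_table": "litestar_session"}},
    )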
@@ -105,6 +106,7 @@ def __init__( statement_config: Default SQL statement configuration driver_features: Optional driver feature configuration bind_key: Optional unique identifier for this configuration + extension_config: Extension-specific configuration (e.g., Litestar plugin settings) """ processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {} @@ -119,6 +121,7 @@ def __init__( statement_config=statement_config, driver_features=driver_features or {}, bind_key=bind_key, + extension_config=extension_config, ) def _create_pool(self) -> "OracleSyncConnectionPool": @@ -224,6 +227,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, + extension_config: "dict[str, dict[str, Any]] | None" = None, ) -> None: """Initialize Oracle asynchronous configuration. @@ -234,6 +238,7 @@ def __init__( statement_config: Default SQL statement configuration driver_features: Optional driver feature configuration bind_key: Optional unique identifier for this configuration + extension_config: Extension-specific configuration (e.g., Litestar plugin settings) """ processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {} @@ -248,6 +253,7 @@ def __init__( statement_config=statement_config or oracledb_statement_config, driver_features=driver_features or {}, bind_key=bind_key, + extension_config=extension_config, ) async def _create_pool(self) -> "OracleAsyncConnectionPool": diff --git a/sqlspec/adapters/oracledb/litestar/__init__.py b/sqlspec/adapters/oracledb/litestar/__init__.py new file mode 100644 index 00000000..94ad596a --- /dev/null +++ b/sqlspec/adapters/oracledb/litestar/__init__.py @@ -0,0 +1,5 @@ +"""Oracle Litestar integration exports.""" + +from sqlspec.adapters.oracledb.litestar.store import OracleAsyncStore, OracleSyncStore + +__all__ = ("OracleAsyncStore", "OracleSyncStore") diff --git a/sqlspec/adapters/oracledb/litestar/store.py b/sqlspec/adapters/oracledb/litestar/store.py new file mode 100644 index 00000000..c830c78e --- /dev/null +++ b/sqlspec/adapters/oracledb/litestar/store.py @@ -0,0 +1,759 @@ +"""Oracle session store for Litestar integration.""" + +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING + +from sqlspec.extensions.litestar.store import BaseSQLSpecStore +from sqlspec.utils.logging import get_logger +from sqlspec.utils.sync_tools import async_ + +if TYPE_CHECKING: + from sqlspec.adapters.oracledb.config import OracleAsyncConfig, OracleSyncConfig + +logger = get_logger("adapters.oracledb.litestar.store") + +ORACLE_SMALL_BLOB_LIMIT = 32000 + +__all__ = ("OracleAsyncStore", "OracleSyncStore") + + +class OracleAsyncStore(BaseSQLSpecStore["OracleAsyncConfig"]): + """Oracle session store using async OracleDB driver. + + Implements server-side session storage for Litestar using Oracle Database + via the async python-oracledb driver. Provides efficient session management with: + - Native async Oracle operations + - MERGE statement for atomic UPSERT + - Automatic expiration handling + - Efficient cleanup of expired sessions + - Optional In-Memory Column Store support (requires Oracle Database In-Memory license) + + Args: + config: OracleAsyncConfig instance. + table_name: Name of the session table. Defaults to "litestar_session". + use_in_memory: Enable Oracle Database In-Memory Column Store for faster queries. + Requires Oracle Database In-Memory license (paid feature). Defaults to False. 
+ + Example: + from sqlspec.adapters.oracledb import OracleAsyncConfig + from sqlspec.adapters.oracledb.litestar.store import OracleAsyncStore + + config = OracleAsyncConfig(pool_config={"dsn": "oracle://..."}) + store = OracleAsyncStore(config) + await store.create_table() + + config_inmem = OracleAsyncConfig(pool_config={"dsn": "oracle://..."}) + store_inmem = OracleAsyncStore(config_inmem, use_in_memory=True) + await store_inmem.create_table() + + Notes: + When use_in_memory=True, the table is created with INMEMORY clause for + faster read operations. This requires Oracle Database 12.1.0.2+ with the + Database In-Memory option licensed. If In-Memory is not available, the + table creation will fail with ORA-00439 or ORA-62142. + """ + + __slots__ = ("_use_in_memory",) + + def __init__( + self, config: "OracleAsyncConfig", table_name: str = "litestar_session", use_in_memory: bool = False + ) -> None: + """Initialize Oracle session store. + + Args: + config: OracleAsyncConfig instance. + table_name: Name of the session table. + use_in_memory: Enable In-Memory Column Store (requires license). + """ + super().__init__(config, table_name) + self._use_in_memory = use_in_memory + + def _get_create_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL with optimized schema. + + Returns: + SQL statement to create the sessions table with proper indexes. + + Notes: + - Uses TIMESTAMP WITH TIME ZONE for timezone-aware expiration timestamps + - Partial index WHERE expires_at IS NOT NULL reduces index size/maintenance + - BLOB type for data storage (Oracle native binary type) + - Audit columns (created_at, updated_at) help with debugging + - Table name is internally controlled, not user input (S608 suppressed) + - INMEMORY clause added when use_in_memory=True for faster reads + """ + inmemory_clause = "INMEMORY" if self._use_in_memory else "" + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._table_name} ( + session_id VARCHAR2(255) PRIMARY KEY, + data BLOB NOT NULL, + expires_at TIMESTAMP WITH TIME ZONE, + created_at TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL + ) {inmemory_clause}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ + + def _get_drop_table_sql(self) -> "list[str]": + """Get Oracle DROP TABLE SQL with PL/SQL error handling. + + Returns: + List of SQL statements with exception handling for non-existent objects. 
+        """
+        return [
+            f"""
+            BEGIN
+                EXECUTE IMMEDIATE 'DROP INDEX idx_{self._table_name}_expires_at';
+            EXCEPTION
+                WHEN OTHERS THEN
+                    IF SQLCODE != -942 AND SQLCODE != -1418 THEN
+                        RAISE;
+                    END IF;
+            END;
+            """,
+            f"""
+            BEGIN
+                EXECUTE IMMEDIATE 'DROP TABLE {self._table_name}';
+            EXCEPTION
+                WHEN OTHERS THEN
+                    IF SQLCODE != -942 THEN
+                        RAISE;
+                    END IF;
+            END;
+            """,
+        ]
+
+    async def create_table(self) -> None:
+        """Create the session table if it doesn't exist."""
+        sql = self._get_create_table_sql()
+        conn_context = self._config.provide_connection()
+        async with conn_context as conn:
+            cursor = conn.cursor()
+            await cursor.execute(sql)
+            await conn.commit()
+
+        index_sql = f"""
+        BEGIN
+            EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._table_name}_expires_at
+                ON {self._table_name}(expires_at)';
+        EXCEPTION
+            WHEN OTHERS THEN
+                IF SQLCODE != -955 THEN
+                    RAISE;
+                END IF;
+        END;
+        """
+        conn_context = self._config.provide_connection()
+        async with conn_context as conn:
+            cursor = conn.cursor()
+            await cursor.execute(index_sql)
+            await conn.commit()
+
+        logger.debug("Created session table: %s", self._table_name)
+
+    async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
+        """Get a session value by key.
+
+        Args:
+            key: Session ID to retrieve.
+            renew_for: If given, renew the expiry time for this duration.
+
+        Returns:
+            Session data as bytes if found and not expired, None otherwise.
+
+        Notes:
+            Uses SYSTIMESTAMP for Oracle current timestamp.
+            The query uses the index for expires_at > SYSTIMESTAMP.
+        """
+        sql = f"""
+        SELECT data, expires_at FROM {self._table_name}
+        WHERE session_id = :session_id
+        AND (expires_at IS NULL OR expires_at > SYSTIMESTAMP)
+        """
+
+        conn_context = self._config.provide_connection()
+        async with conn_context as conn:
+            cursor = conn.cursor()
+            await cursor.execute(sql, {"session_id": key})
+            row = await cursor.fetchone()
+
+            if row is None:
+                return None
+
+            data_blob, expires_at = row
+
+            if renew_for is not None and expires_at is not None:
+                new_expires_at = self._calculate_expires_at(renew_for)
+                if new_expires_at is not None:
+                    update_sql = f"""
+                    UPDATE {self._table_name}
+                    SET expires_at = :expires_at, updated_at = SYSTIMESTAMP
+                    WHERE session_id = :session_id
+                    """
+                    await cursor.execute(update_sql, {"expires_at": new_expires_at, "session_id": key})
+                    await conn.commit()
+
+            try:
+                blob_data = await data_blob.read()
+                return bytes(blob_data) if blob_data is not None else bytes(data_blob)
+            except AttributeError:
+                return bytes(data_blob)
+
+    async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None:
+        """Store a session value.
+
+        Args:
+            key: Session ID.
+            value: Session data.
+            expires_in: Time until expiration.
+
+        Notes:
+            Uses MERGE for atomic UPSERT operation in Oracle.
+            Updates updated_at timestamp on every write for audit trail.
+            For large BLOBs, uses empty_blob() and then writes data separately.
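+            Condensed sketch of that two-step path (names as in the body below;
+            "params" abbreviates the bind dictionaries):
+
+                await cursor.execute(merge_sql, params)    # row now holds EMPTY_BLOB()
+                await cursor.execute(select_sql, params)   # SELECT data ... FOR UPDATE
+                lob = (await cursor.fetchone())[0]
+                await lob.write(data)                      # stream bytes into the locked LOB
+                await conn.commit()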
+ """ + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + + conn_context = self._config.provide_connection() + async with conn_context as conn: + cursor = conn.cursor() + + if len(data) > ORACLE_SMALL_BLOB_LIMIT: + merge_sql = f""" + MERGE INTO {self._table_name} t + USING (SELECT :session_id AS session_id FROM DUAL) s + ON (t.session_id = s.session_id) + WHEN MATCHED THEN + UPDATE SET + data = EMPTY_BLOB(), + expires_at = :expires_at, + updated_at = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (session_id, data, expires_at, created_at, updated_at) + VALUES (:session_id, EMPTY_BLOB(), :expires_at, SYSTIMESTAMP, SYSTIMESTAMP) + """ + await cursor.execute(merge_sql, {"session_id": key, "expires_at": expires_at}) + + select_sql = f""" + SELECT data FROM {self._table_name} + WHERE session_id = :session_id FOR UPDATE + """ + await cursor.execute(select_sql, {"session_id": key}) + row = await cursor.fetchone() + if row: + blob = row[0] + await blob.write(data) + + await conn.commit() + else: + sql = f""" + MERGE INTO {self._table_name} t + USING (SELECT :session_id AS session_id FROM DUAL) s + ON (t.session_id = s.session_id) + WHEN MATCHED THEN + UPDATE SET + data = :data, + expires_at = :expires_at, + updated_at = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (session_id, data, expires_at, created_at, updated_at) + VALUES (:session_id, :data, :expires_at, SYSTIMESTAMP, SYSTIMESTAMP) + """ + await cursor.execute(sql, {"session_id": key, "data": data, "expires_at": expires_at}) + await conn.commit() + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. + """ + sql = f"DELETE FROM {self._table_name} WHERE session_id = :session_id" + + conn_context = self._config.provide_connection() + async with conn_context as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"session_id": key}) + await conn.commit() + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + sql = f"DELETE FROM {self._table_name}" + + conn_context = self._config.provide_connection() + async with conn_context as conn: + cursor = conn.cursor() + await cursor.execute(sql) + await conn.commit() + logger.debug("Deleted all sessions from table: %s", self._table_name) + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + + Notes: + Uses SYSTIMESTAMP for consistency with get() method. + """ + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = :session_id + AND (expires_at IS NULL OR expires_at > SYSTIMESTAMP) + """ + + conn_context = self._config.provide_connection() + async with conn_context as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"session_id": key}) + result = await cursor.fetchone() + return result is not None + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. 
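+
+        Notes:
+            The remaining lifetime is computed client-side (sketch):
+
+                delta = expires_at - datetime.now(timezone.utc)
+                remaining = max(int(delta.total_seconds()), 0)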
+ """ + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = :session_id + """ + + conn_context = self._config.provide_connection() + async with conn_context as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"session_id": key}) + row = await cursor.fetchone() + + if row is None or row[0] is None: + return None + + expires_at = row[0] + + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + + now = datetime.now(timezone.utc) + + if expires_at <= now: + return 0 + + delta = expires_at - now + return int(delta.total_seconds()) + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. + + Notes: + Uses SYSTIMESTAMP for consistency. + Oracle automatically commits DDL, so we explicitly commit for DML. + """ + sql = f"DELETE FROM {self._table_name} WHERE expires_at <= SYSTIMESTAMP" + + conn_context = self._config.provide_connection() + async with conn_context as conn: + cursor = conn.cursor() + await cursor.execute(sql) + count = cursor.rowcount if cursor.rowcount is not None else 0 + await conn.commit() + if count > 0: + logger.debug("Cleaned up %d expired sessions", count) + return count + + +class OracleSyncStore(BaseSQLSpecStore["OracleSyncConfig"]): + """Oracle session store using sync OracleDB driver. + + Implements server-side session storage for Litestar using Oracle Database + via the synchronous python-oracledb driver. Uses async_() wrapper to provide + an async interface compatible with the Store protocol. + + Provides efficient session management with: + - Sync operations wrapped for async compatibility + - MERGE statement for atomic UPSERT + - Automatic expiration handling + - Efficient cleanup of expired sessions + - Optional In-Memory Column Store support (requires Oracle Database In-Memory license) + + Note: + For high-concurrency applications, consider using OracleAsyncStore instead, + as it provides native async operations without threading overhead. + + Args: + config: OracleSyncConfig instance. + table_name: Name of the session table. Defaults to "litestar_session". + use_in_memory: Enable Oracle Database In-Memory Column Store for faster queries. + Requires Oracle Database In-Memory license (paid feature). Defaults to False. + + Example: + from sqlspec.adapters.oracledb import OracleSyncConfig + from sqlspec.adapters.oracledb.litestar.store import OracleSyncStore + + config = OracleSyncConfig(pool_config={"dsn": "oracle://..."}) + store = OracleSyncStore(config) + await store.create_table() + + Notes: + When use_in_memory=True, the table is created with INMEMORY clause for + faster read operations. This requires Oracle Database 12.1.0.2+ with the + Database In-Memory option licensed. If In-Memory is not available, the + table creation will fail with ORA-00439 or ORA-62142. + """ + + __slots__ = ("_use_in_memory",) + + def __init__( + self, config: "OracleSyncConfig", table_name: str = "litestar_session", use_in_memory: bool = False + ) -> None: + """Initialize Oracle sync session store. + + Args: + config: OracleSyncConfig instance. + table_name: Name of the session table. + use_in_memory: Enable In-Memory Column Store (requires license). + """ + super().__init__(config, table_name) + self._use_in_memory = use_in_memory + + def _get_create_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL with optimized schema. + + Returns: + SQL statement to create the sessions table with proper indexes. 
+
+        Notes:
+            - Uses TIMESTAMP WITH TIME ZONE for timezone-aware expiration timestamps
+            - Partial index WHERE expires_at IS NOT NULL reduces index size/maintenance
+            - BLOB type for data storage (Oracle native binary type)
+            - Audit columns (created_at, updated_at) help with debugging
+            - Table name is internally controlled, not user input (S608 suppressed)
+            - INMEMORY clause added when use_in_memory=True for faster reads
+        """
+        inmemory_clause = "INMEMORY" if self._use_in_memory else ""
+        return f"""
+        BEGIN
+            EXECUTE IMMEDIATE 'CREATE TABLE {self._table_name} (
+                session_id VARCHAR2(255) PRIMARY KEY,
+                data BLOB NOT NULL,
+                expires_at TIMESTAMP WITH TIME ZONE,
+                created_at TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL,
+                updated_at TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL
+            ) {inmemory_clause}';
+        EXCEPTION
+            WHEN OTHERS THEN
+                IF SQLCODE != -955 THEN
+                    RAISE;
+                END IF;
+        END;
+        """
+
+    def _get_drop_table_sql(self) -> "list[str]":
+        """Get Oracle DROP TABLE SQL with PL/SQL error handling.
+
+        Returns:
+            List of SQL statements with exception handling for non-existent objects.
+        """
+        return [
+            f"""
+            BEGIN
+                EXECUTE IMMEDIATE 'DROP INDEX idx_{self._table_name}_expires_at';
+            EXCEPTION
+                WHEN OTHERS THEN
+                    IF SQLCODE != -942 AND SQLCODE != -1418 THEN
+                        RAISE;
+                    END IF;
+            END;
+            """,
+            f"""
+            BEGIN
+                EXECUTE IMMEDIATE 'DROP TABLE {self._table_name}';
+            EXCEPTION
+                WHEN OTHERS THEN
+                    IF SQLCODE != -942 THEN
+                        RAISE;
+                    END IF;
+            END;
+            """,
+        ]
+
+    def _create_table(self) -> None:
+        """Synchronous implementation of create_table."""
+        sql = self._get_create_table_sql()
+        with self._config.provide_connection() as conn:
+            cursor = conn.cursor()
+            cursor.execute(sql)
+            conn.commit()
+
+        index_sql = f"""
+        BEGIN
+            EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._table_name}_expires_at
+                ON {self._table_name}(expires_at)';
+        EXCEPTION
+            WHEN OTHERS THEN
+                IF SQLCODE != -955 THEN
+                    RAISE;
+                END IF;
+        END;
+        """
+        with self._config.provide_connection() as conn:
+            cursor = conn.cursor()
+            cursor.execute(index_sql)
+            conn.commit()
+
+        logger.debug("Created session table: %s", self._table_name)
+
+    async def create_table(self) -> None:
+        """Create the session table if it doesn't exist."""
+        await async_(self._create_table)()
+
+    def _get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
+        """Synchronous implementation of get.
+
+        Notes:
+            Uses SYSTIMESTAMP for Oracle current timestamp.
+ """ + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = :session_id + AND (expires_at IS NULL OR expires_at > SYSTIMESTAMP) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"session_id": key}) + row = cursor.fetchone() + + if row is None: + return None + + data_blob, expires_at = row + + if renew_for is not None and expires_at is not None: + new_expires_at = self._calculate_expires_at(renew_for) + if new_expires_at is not None: + update_sql = f""" + UPDATE {self._table_name} + SET expires_at = :expires_at, updated_at = SYSTIMESTAMP + WHERE session_id = :session_id + """ + cursor.execute(update_sql, {"expires_at": new_expires_at, "session_id": key}) + conn.commit() + + try: + if hasattr(data_blob, "read"): + blob_data = data_blob.read() + return bytes(blob_data) if blob_data is not None else bytes(data_blob) + return bytes(data_blob) + except AttributeError: + return bytes(data_blob) + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given, renew the expiry time for this duration. + + Returns: + Session data as bytes if found and not expired, None otherwise. + """ + return await async_(self._get)(key, renew_for) + + def _set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Synchronous implementation of set. + + Notes: + Uses MERGE for atomic UPSERT operation in Oracle. + """ + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + + if len(data) > ORACLE_SMALL_BLOB_LIMIT: + merge_sql = f""" + MERGE INTO {self._table_name} t + USING (SELECT :session_id AS session_id FROM DUAL) s + ON (t.session_id = s.session_id) + WHEN MATCHED THEN + UPDATE SET + data = EMPTY_BLOB(), + expires_at = :expires_at, + updated_at = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (session_id, data, expires_at, created_at, updated_at) + VALUES (:session_id, EMPTY_BLOB(), :expires_at, SYSTIMESTAMP, SYSTIMESTAMP) + """ + cursor.execute(merge_sql, {"session_id": key, "expires_at": expires_at}) + + select_sql = f""" + SELECT data FROM {self._table_name} + WHERE session_id = :session_id FOR UPDATE + """ + cursor.execute(select_sql, {"session_id": key}) + row = cursor.fetchone() + if row: + blob = row[0] + blob.write(data) + + conn.commit() + else: + sql = f""" + MERGE INTO {self._table_name} t + USING (SELECT :session_id AS session_id FROM DUAL) s + ON (t.session_id = s.session_id) + WHEN MATCHED THEN + UPDATE SET + data = :data, + expires_at = :expires_at, + updated_at = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (session_id, data, expires_at, created_at, updated_at) + VALUES (:session_id, :data, :expires_at, SYSTIMESTAMP, SYSTIMESTAMP) + """ + cursor.execute(sql, {"session_id": key, "data": data, "expires_at": expires_at}) + conn.commit() + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data. + expires_in: Time until expiration. 
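+
+        Example (illustrative):
+            from datetime import timedelta
+
+            await store.set("sid", b"payload", expires_in=timedelta(hours=1))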
+ """ + await async_(self._set)(key, value, expires_in) + + def _delete(self, key: str) -> None: + """Synchronous implementation of delete.""" + sql = f"DELETE FROM {self._table_name} WHERE session_id = :session_id" + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"session_id": key}) + conn.commit() + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. + """ + await async_(self._delete)(key) + + def _delete_all(self) -> None: + """Synchronous implementation of delete_all.""" + sql = f"DELETE FROM {self._table_name}" + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql) + conn.commit() + logger.debug("Deleted all sessions from table: %s", self._table_name) + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + await async_(self._delete_all)() + + def _exists(self, key: str) -> bool: + """Synchronous implementation of exists.""" + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = :session_id + AND (expires_at IS NULL OR expires_at > SYSTIMESTAMP) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"session_id": key}) + result = cursor.fetchone() + return result is not None + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + """ + return await async_(self._exists)(key) + + def _expires_in(self, key: str) -> "int | None": + """Synchronous implementation of expires_in.""" + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = :session_id + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"session_id": key}) + row = cursor.fetchone() + + if row is None or row[0] is None: + return None + + expires_at = row[0] + + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + + now = datetime.now(timezone.utc) + + if expires_at <= now: + return 0 + + delta = expires_at - now + return int(delta.total_seconds()) + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. + """ + return await async_(self._expires_in)(key) + + def _delete_expired(self) -> int: + """Synchronous implementation of delete_expired.""" + sql = f"DELETE FROM {self._table_name} WHERE expires_at <= SYSTIMESTAMP" + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql) + count = cursor.rowcount if cursor.rowcount is not None else 0 + conn.commit() + if count > 0: + logger.debug("Cleaned up %d expired sessions", count) + return count + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. 
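+
+        Example (illustrative periodic sweep):
+            import asyncio
+
+            async def sweep(store: "OracleSyncStore") -> None:
+                while True:
+                    await store.delete_expired()
+                    await asyncio.sleep(3600)   # hourly cleanup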
+        """
+        return await async_(self._delete_expired)()
diff --git a/sqlspec/adapters/psqlpy/config.py b/sqlspec/adapters/psqlpy/config.py
index a331bb22..78eacecd 100644
--- a/sqlspec/adapters/psqlpy/config.py
+++ b/sqlspec/adapters/psqlpy/config.py
@@ -91,6 +91,7 @@ def __init__(
         statement_config: StatementConfig | None = None,
         driver_features: dict[str, Any] | None = None,
         bind_key: str | None = None,
+        extension_config: "dict[str, dict[str, Any]] | None" = None,
     ) -> None:
         """Initialize Psqlpy configuration.

@@ -101,6 +102,7 @@ def __init__(
             statement_config: SQL statement configuration
             driver_features: Driver feature configuration
             bind_key: Optional unique identifier for this configuration
+            extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
         """
         processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {}
         if "extra" in processed_pool_config:
@@ -113,6 +115,7 @@ def __init__(
             statement_config=statement_config or psqlpy_statement_config,
             driver_features=driver_features or {},
             bind_key=bind_key,
+            extension_config=extension_config,
         )

     def _get_pool_config_dict(self) -> dict[str, Any]:
diff --git a/sqlspec/adapters/psqlpy/litestar/__init__.py b/sqlspec/adapters/psqlpy/litestar/__init__.py
new file mode 100644
index 00000000..c4c5e1e9
--- /dev/null
+++ b/sqlspec/adapters/psqlpy/litestar/__init__.py
@@ -0,0 +1,5 @@
+"""Litestar integration for psqlpy adapter."""
+
+from sqlspec.adapters.psqlpy.litestar.store import PsqlpyStore
+
+__all__ = ("PsqlpyStore",)
diff --git a/sqlspec/adapters/psqlpy/litestar/store.py b/sqlspec/adapters/psqlpy/litestar/store.py
new file mode 100644
index 00000000..63b0416d
--- /dev/null
+++ b/sqlspec/adapters/psqlpy/litestar/store.py
@@ -0,0 +1,271 @@
+"""Psqlpy session store for Litestar integration."""
+
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING
+
+from sqlspec.extensions.litestar.store import BaseSQLSpecStore
+from sqlspec.utils.logging import get_logger
+
+if TYPE_CHECKING:
+    from sqlspec.adapters.psqlpy.config import PsqlpyConfig
+
+logger = get_logger("adapters.psqlpy.litestar.store")
+
+__all__ = ("PsqlpyStore",)
+
+
+class PsqlpyStore(BaseSQLSpecStore["PsqlpyConfig"]):
+    """PostgreSQL session store using Psqlpy driver.
+
+    Implements server-side session storage for Litestar using PostgreSQL
+    via the Psqlpy driver (Rust-based async driver). Provides efficient
+    session management with:
+    - Native async PostgreSQL operations via Rust
+    - UPSERT support using ON CONFLICT
+    - Automatic expiration handling
+    - Efficient cleanup of expired sessions
+
+    Args:
+        config: PsqlpyConfig instance.
+        table_name: Name of the session table. Defaults to "litestar_session".
+
+    Example:
+        from sqlspec.adapters.psqlpy import PsqlpyConfig
+        from sqlspec.adapters.psqlpy.litestar.store import PsqlpyStore
+
+        config = PsqlpyConfig(pool_config={"dsn": "postgresql://..."})
+        store = PsqlpyStore(config)
+        await store.create_table()
+    """
+
+    __slots__ = ()
+
+    def __init__(self, config: "PsqlpyConfig", table_name: str = "litestar_session") -> None:
+        """Initialize Psqlpy session store.
+
+        Args:
+            config: PsqlpyConfig instance.
+            table_name: Name of the session table.
+        """
+        super().__init__(config, table_name)
+
+    def _get_create_table_sql(self) -> str:
+        """Get PostgreSQL CREATE TABLE SQL with optimized schema.
+
+        Returns:
+            SQL statement to create the sessions table with proper indexes.
+ + Notes: + - Uses TIMESTAMPTZ for timezone-aware expiration timestamps + - Partial index WHERE expires_at IS NOT NULL reduces index size/maintenance + - FILLFACTOR 80 leaves space for HOT updates, reducing table bloat + - Audit columns (created_at, updated_at) help with debugging + - Table name is internally controlled, not user input (S608 suppressed) + """ + return f""" + CREATE TABLE IF NOT EXISTS {self._table_name} ( + session_id TEXT PRIMARY KEY, + data BYTEA NOT NULL, + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at + ON {self._table_name}(expires_at) WHERE expires_at IS NOT NULL; + + ALTER TABLE {self._table_name} SET ( + autovacuum_vacuum_scale_factor = 0.05, + autovacuum_analyze_scale_factor = 0.02 + ); + """ + + def _get_drop_table_sql(self) -> "list[str]": + """Get PostgreSQL DROP TABLE SQL statements. + + Returns: + List of SQL statements to drop indexes and table. + """ + return [f"DROP INDEX IF EXISTS idx_{self._table_name}_expires_at", f"DROP TABLE IF EXISTS {self._table_name}"] + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + sql = self._get_create_table_sql() + async with self._config.provide_connection() as conn: + await conn.execute_batch(sql) + logger.debug("Created session table: %s", self._table_name) + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given, renew the expiry time for this duration. + + Returns: + Session data as bytes if found and not expired, None otherwise. + + Notes: + Uses CURRENT_TIMESTAMP instead of NOW() for SQL standard compliance. + The query planner can use the partial index for expires_at > CURRENT_TIMESTAMP. + """ + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = $1 + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + query_result = await conn.fetch(sql, [key]) + rows = query_result.result() + + if not rows: + return None + + row = rows[0] + + if renew_for is not None and row["expires_at"] is not None: + new_expires_at = self._calculate_expires_at(renew_for) + if new_expires_at is not None: + update_sql = f""" + UPDATE {self._table_name} + SET expires_at = $1, updated_at = CURRENT_TIMESTAMP + WHERE session_id = $2 + """ + await conn.execute(update_sql, [new_expires_at, key]) + + return bytes(row["data"]) + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data. + expires_in: Time until expiration. + + Notes: + Uses EXCLUDED to reference the proposed insert values in ON CONFLICT. + Updates updated_at timestamp on every write for audit trail. 
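+            The upsert makes writes idempotent per key (illustrative):
+
+                await store.set("sid", b"v1")
+                await store.set("sid", b"v2")   # ON CONFLICT path updates in place
+                assert await store.get("sid") == b"v2"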
+ """ + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + + sql = f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES ($1, $2, $3) + ON CONFLICT (session_id) + DO UPDATE SET + data = EXCLUDED.data, + expires_at = EXCLUDED.expires_at, + updated_at = CURRENT_TIMESTAMP + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, [key, data, expires_at]) + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. + """ + sql = f"DELETE FROM {self._table_name} WHERE session_id = $1" + + async with self._config.provide_connection() as conn: + await conn.execute(sql, [key]) + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + sql = f"DELETE FROM {self._table_name}" + + async with self._config.provide_connection() as conn: + await conn.execute(sql) + logger.debug("Deleted all sessions from table: %s", self._table_name) + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + + Notes: + Uses CURRENT_TIMESTAMP for consistency with get() method. + Uses fetch() instead of fetch_val() to handle zero-row case. + """ + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = $1 + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + query_result = await conn.fetch(sql, [key]) + rows = query_result.result() + return len(rows) > 0 + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. + + Notes: + Uses fetch() to handle the case where the key doesn't exist. + """ + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = $1 + """ + + async with self._config.provide_connection() as conn: + query_result = await conn.fetch(sql, [key]) + rows = query_result.result() + + if not rows: + return None + + expires_at = rows[0]["expires_at"] + + if expires_at is None: + return None + + now = datetime.now(timezone.utc) + if expires_at <= now: + return 0 + + delta = expires_at - now + return int(delta.total_seconds()) + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. + + Notes: + Uses CURRENT_TIMESTAMP for consistency. + Uses RETURNING to get deleted row count since psqlpy QueryResult + doesn't expose command tags. + For very large tables (10M+ rows), consider batching deletes + to avoid holding locks too long. 
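+            A batched variant under that advice (sketch; table name inlined,
+            reusing the conn.fetch() call style from this class):
+
+                batch_sql = '''
+                    DELETE FROM litestar_session
+                    WHERE session_id IN (
+                        SELECT session_id FROM litestar_session
+                        WHERE expires_at <= CURRENT_TIMESTAMP
+                        LIMIT 10000)
+                    RETURNING session_id
+                '''
+                while (await conn.fetch(batch_sql, [])).result():
+                    pass   # repeat until no expired rows remain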
+ """ + sql = f""" + DELETE FROM {self._table_name} + WHERE expires_at <= CURRENT_TIMESTAMP + RETURNING session_id + """ + + async with self._config.provide_connection() as conn: + query_result = await conn.fetch(sql, []) + rows = query_result.result() + count = len(rows) + if count > 0: + logger.debug("Cleaned up %d expired sessions", count) + return count diff --git a/sqlspec/adapters/psycopg/config.py b/sqlspec/adapters/psycopg/config.py index 84bff575..3ec37841 100644 --- a/sqlspec/adapters/psycopg/config.py +++ b/sqlspec/adapters/psycopg/config.py @@ -3,7 +3,7 @@ import contextlib import logging from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast from psycopg.rows import dict_row from psycopg_pool import AsyncConnectionPool, ConnectionPool @@ -84,11 +84,12 @@ def __init__( self, *, pool_config: "PsycopgPoolParams | dict[str, Any] | None" = None, - pool_instance: Optional["ConnectionPool"] = None, + pool_instance: "ConnectionPool | None" = None, migration_config: dict[str, Any] | None = None, statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, + extension_config: "dict[str, dict[str, Any]] | None" = None, ) -> None: """Initialize Psycopg synchronous configuration. @@ -99,6 +100,7 @@ def __init__( statement_config: Default SQL statement configuration driver_features: Optional driver feature configuration bind_key: Optional unique identifier for this configuration + extension_config: Extension-specific configuration (e.g., Litestar plugin settings) """ processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {} if "extra" in processed_pool_config: @@ -112,6 +114,7 @@ def __init__( statement_config=statement_config or psycopg_statement_config, driver_features=driver_features or {}, bind_key=bind_key, + extension_config=extension_config, ) def _create_pool(self) -> "ConnectionPool": @@ -274,6 +277,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, + extension_config: "dict[str, dict[str, Any]] | None" = None, ) -> None: """Initialize Psycopg asynchronous configuration. 
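The RETURNING-based counting used above is portable. The same technique with
the stdlib sqlite3 driver (runnable sketch; needs SQLite 3.35+ for RETURNING):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE s (session_id TEXT PRIMARY KEY, expires_at REAL)")
    conn.executemany("INSERT INTO s VALUES (?, ?)", [("a", 0.0), ("b", 9e18)])
    deleted = conn.execute(
        "DELETE FROM s WHERE expires_at <= strftime('%s', 'now') RETURNING session_id"
    ).fetchall()
    print(len(deleted))  # 1 expired row removed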
@@ -284,6 +288,7 @@ def __init__(
             statement_config: Default SQL statement configuration
             driver_features: Optional driver feature configuration
             bind_key: Optional unique identifier for this configuration
+            extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
         """
         processed_pool_config: dict[str, Any] = dict(pool_config) if pool_config else {}
         if "extra" in processed_pool_config:
@@ -297,6 +302,7 @@ def __init__(
             statement_config=statement_config or psycopg_statement_config,
             driver_features=driver_features or {},
             bind_key=bind_key,
+            extension_config=extension_config,
         )

     async def _create_pool(self) -> "AsyncConnectionPool":
diff --git a/sqlspec/adapters/psycopg/litestar/__init__.py b/sqlspec/adapters/psycopg/litestar/__init__.py
new file mode 100644
index 00000000..8041fae7
--- /dev/null
+++ b/sqlspec/adapters/psycopg/litestar/__init__.py
@@ -0,0 +1,5 @@
+"""Litestar integration for Psycopg adapter."""
+
+from sqlspec.adapters.psycopg.litestar.store import PsycopgAsyncStore, PsycopgSyncStore
+
+__all__ = ("PsycopgAsyncStore", "PsycopgSyncStore")
diff --git a/sqlspec/adapters/psycopg/litestar/store.py b/sqlspec/adapters/psycopg/litestar/store.py
new file mode 100644
index 00000000..ff7eba63
--- /dev/null
+++ b/sqlspec/adapters/psycopg/litestar/store.py
@@ -0,0 +1,561 @@
+"""Psycopg session stores for Litestar integration.
+
+Provides both async and sync PostgreSQL session stores using psycopg3.
+"""
+
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING
+
+from sqlspec.extensions.litestar.store import BaseSQLSpecStore
+from sqlspec.utils.logging import get_logger
+from sqlspec.utils.sync_tools import async_
+
+if TYPE_CHECKING:
+    from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgSyncConfig
+
+logger = get_logger("adapters.psycopg.litestar.store")
+
+__all__ = ("PsycopgAsyncStore", "PsycopgSyncStore")
+
+
+class PsycopgAsyncStore(BaseSQLSpecStore["PsycopgAsyncConfig"]):
+    """PostgreSQL session store using Psycopg async driver.
+
+    Implements server-side session storage for Litestar using PostgreSQL
+    via the Psycopg (psycopg3) async driver. Provides efficient session
+    management with:
+    - Native async PostgreSQL operations
+    - UPSERT support using ON CONFLICT
+    - Automatic expiration handling
+    - Efficient cleanup of expired sessions
+
+    Args:
+        config: PsycopgAsyncConfig instance.
+        table_name: Name of the session table. Defaults to "litestar_session".
+
+    Example:
+        from sqlspec.adapters.psycopg import PsycopgAsyncConfig
+        from sqlspec.adapters.psycopg.litestar.store import PsycopgAsyncStore
+
+        config = PsycopgAsyncConfig(pool_config={"conninfo": "postgresql://..."})
+        store = PsycopgAsyncStore(config)
+        await store.create_table()
+    """
+
+    __slots__ = ()
+
+    def __init__(self, config: "PsycopgAsyncConfig", table_name: str = "litestar_session") -> None:
+        """Initialize Psycopg async session store.
+
+        Args:
+            config: PsycopgAsyncConfig instance.
+            table_name: Name of the session table.
+        """
+        super().__init__(config, table_name)
+
+    def _get_create_table_sql(self) -> str:
+        """Get PostgreSQL CREATE TABLE SQL with optimized schema.
+
+        Returns:
+            SQL statement to create the sessions table with proper indexes.
+ + Notes: + - Uses TIMESTAMPTZ for timezone-aware expiration timestamps + - Partial index WHERE expires_at IS NOT NULL reduces index size/maintenance + - FILLFACTOR 80 leaves space for HOT updates, reducing table bloat + - Audit columns (created_at, updated_at) help with debugging + - Table name is internally controlled, not user input (S608 suppressed) + """ + return f""" + CREATE TABLE IF NOT EXISTS {self._table_name} ( + session_id TEXT PRIMARY KEY, + data BYTEA NOT NULL, + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at + ON {self._table_name}(expires_at) WHERE expires_at IS NOT NULL; + + ALTER TABLE {self._table_name} SET ( + autovacuum_vacuum_scale_factor = 0.05, + autovacuum_analyze_scale_factor = 0.02 + ); + """ + + def _get_drop_table_sql(self) -> "list[str]": + """Get PostgreSQL DROP TABLE SQL statements. + + Returns: + List of SQL statements to drop indexes and table. + """ + return [f"DROP INDEX IF EXISTS idx_{self._table_name}_expires_at", f"DROP TABLE IF EXISTS {self._table_name}"] + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + sql = self._get_create_table_sql() + conn_context = self._config.provide_connection() + async with conn_context as conn: + async with conn.cursor() as cur: + for statement in sql.strip().split(";"): + statement = statement.strip() + if statement: + await cur.execute(statement.encode()) + await conn.commit() + logger.debug("Created session table: %s", self._table_name) + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given, renew the expiry time for this duration. + + Returns: + Session data as bytes if found and not expired, None otherwise. + + Notes: + Uses CURRENT_TIMESTAMP instead of NOW() for SQL standard compliance. + The query planner can use the partial index for expires_at > CURRENT_TIMESTAMP. + """ + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = %s + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) + """ + + conn_context = self._config.provide_connection() + async with conn_context as conn: + async with conn.cursor() as cur: + await cur.execute(sql.encode(), (key,)) + row = await cur.fetchone() + + if row is None: + return None + + if renew_for is not None and row["expires_at"] is not None: + new_expires_at = self._calculate_expires_at(renew_for) + if new_expires_at is not None: + update_sql = f""" + UPDATE {self._table_name} + SET expires_at = %s, updated_at = CURRENT_TIMESTAMP + WHERE session_id = %s + """ + await conn.execute(update_sql.encode(), (new_expires_at, key)) + await conn.commit() + + return bytes(row["data"]) + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data. + expires_in: Time until expiration. + + Notes: + Uses EXCLUDED to reference the proposed insert values in ON CONFLICT. + Updates updated_at timestamp on every write for audit trail. 
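+            The statement is sent as bytes (sql.encode()) with %s placeholders,
+            and psycopg adapts the Python values directly (illustrative):
+
+                await conn.execute(sql.encode(), ("sid", b"payload", None))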
+ """ + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + + sql = f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES (%s, %s, %s) + ON CONFLICT (session_id) + DO UPDATE SET + data = EXCLUDED.data, + expires_at = EXCLUDED.expires_at, + updated_at = CURRENT_TIMESTAMP + """ + + conn_context = self._config.provide_connection() + async with conn_context as conn: + await conn.execute(sql.encode(), (key, data, expires_at)) + await conn.commit() + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. + """ + sql = f"DELETE FROM {self._table_name} WHERE session_id = %s" + + conn_context = self._config.provide_connection() + async with conn_context as conn: + await conn.execute(sql.encode(), (key,)) + await conn.commit() + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + sql = f"DELETE FROM {self._table_name}" + + conn_context = self._config.provide_connection() + async with conn_context as conn: + await conn.execute(sql.encode()) + await conn.commit() + logger.debug("Deleted all sessions from table: %s", self._table_name) + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + + Notes: + Uses CURRENT_TIMESTAMP for consistency with get() method. + """ + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = %s + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) + """ + + conn_context = self._config.provide_connection() + async with conn_context as conn, conn.cursor() as cur: + await cur.execute(sql.encode(), (key,)) + result = await cur.fetchone() + return result is not None + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. + """ + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = %s + """ + + conn_context = self._config.provide_connection() + async with conn_context as conn: + async with conn.cursor() as cur: + await cur.execute(sql.encode(), (key,)) + row = await cur.fetchone() + + if row is None or row["expires_at"] is None: + return None + + expires_at = row["expires_at"] + now = datetime.now(timezone.utc) + if expires_at <= now: + return 0 + + delta = expires_at - now + return int(delta.total_seconds()) + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. + + Notes: + Uses CURRENT_TIMESTAMP for consistency. + For very large tables (10M+ rows), consider batching deletes + to avoid holding locks too long. + """ + sql = f"DELETE FROM {self._table_name} WHERE expires_at <= CURRENT_TIMESTAMP" + + conn_context = self._config.provide_connection() + async with conn_context as conn, conn.cursor() as cur: + await cur.execute(sql.encode()) + await conn.commit() + count = cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + if count > 0: + logger.debug("Cleaned up %d expired sessions", count) + return count + + +class PsycopgSyncStore(BaseSQLSpecStore["PsycopgSyncConfig"]): + """PostgreSQL session store using Psycopg sync driver. + + Implements server-side session storage for Litestar using PostgreSQL + via the synchronous Psycopg (psycopg3) driver. 
Uses Litestar's sync_to_thread + utility to provide an async interface compatible with the Store protocol. + + Provides efficient session management with: + - Sync operations wrapped for async compatibility + - UPSERT support using ON CONFLICT + - Automatic expiration handling + - Efficient cleanup of expired sessions + + Note: + For high-concurrency applications, consider using PsycopgAsyncStore instead, + as it provides native async operations without threading overhead. + + Args: + config: PsycopgSyncConfig instance. + table_name: Name of the session table. Defaults to "litestar_session". + + Example: + from sqlspec.adapters.psycopg import PsycopgSyncConfig + from sqlspec.adapters.psycopg.litestar.store import PsycopgSyncStore + + config = PsycopgSyncConfig(pool_config={"conninfo": "postgresql://..."}) + store = PsycopgSyncStore(config) + await store.create_table() + """ + + __slots__ = () + + def __init__(self, config: "PsycopgSyncConfig", table_name: str = "litestar_session") -> None: + """Initialize Psycopg sync session store. + + Args: + config: PsycopgSyncConfig instance. + table_name: Name of the session table. + """ + super().__init__(config, table_name) + + def _get_create_table_sql(self) -> str: + """Get PostgreSQL CREATE TABLE SQL with optimized schema. + + Returns: + SQL statement to create the sessions table with proper indexes. + + Notes: + - Uses TIMESTAMPTZ for timezone-aware expiration timestamps + - Partial index WHERE expires_at IS NOT NULL reduces index size/maintenance + - FILLFACTOR 80 leaves space for HOT updates, reducing table bloat + - Audit columns (created_at, updated_at) help with debugging + - Table name is internally controlled, not user input (S608 suppressed) + """ + return f""" + CREATE TABLE IF NOT EXISTS {self._table_name} ( + session_id TEXT PRIMARY KEY, + data BYTEA NOT NULL, + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at + ON {self._table_name}(expires_at) WHERE expires_at IS NOT NULL; + + ALTER TABLE {self._table_name} SET ( + autovacuum_vacuum_scale_factor = 0.05, + autovacuum_analyze_scale_factor = 0.02 + ); + """ + + def _get_drop_table_sql(self) -> "list[str]": + """Get PostgreSQL DROP TABLE SQL statements. + + Returns: + List of SQL statements to drop indexes and table. + """ + return [f"DROP INDEX IF EXISTS idx_{self._table_name}_expires_at", f"DROP TABLE IF EXISTS {self._table_name}"] + + def _create_table(self) -> None: + """Synchronous implementation of create_table.""" + sql = self._get_create_table_sql() + with self._config.provide_connection() as conn: + with conn.cursor() as cur: + for statement in sql.strip().split(";"): + statement = statement.strip() + if statement: + cur.execute(statement.encode()) + conn.commit() + logger.debug("Created session table: %s", self._table_name) + + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + await async_(self._create_table)() + + def _get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Synchronous implementation of get. + + Notes: + Uses CURRENT_TIMESTAMP for SQL standard compliance. 
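+            Comparing the bare expires_at column (rather than wrapping it in a
+            function) keeps the partial index usable by the query planner.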
+ """ + sql = f""" + SELECT data, expires_at FROM {self._table_name} + WHERE session_id = %s + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) + """ + + with self._config.provide_connection() as conn: + with conn.cursor() as cur: + cur.execute(sql.encode(), (key,)) + row = cur.fetchone() + + if row is None: + return None + + if renew_for is not None and row["expires_at"] is not None: + new_expires_at = self._calculate_expires_at(renew_for) + if new_expires_at is not None: + update_sql = f""" + UPDATE {self._table_name} + SET expires_at = %s, updated_at = CURRENT_TIMESTAMP + WHERE session_id = %s + """ + conn.execute(update_sql.encode(), (new_expires_at, key)) + conn.commit() + + return bytes(row["data"]) + + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given, renew the expiry time for this duration. + + Returns: + Session data as bytes if found and not expired, None otherwise. + """ + return await async_(self._get)(key, renew_for) + + def _set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Synchronous implementation of set. + + Notes: + Uses EXCLUDED to reference the proposed insert values in ON CONFLICT. + """ + data = self._value_to_bytes(value) + expires_at = self._calculate_expires_at(expires_in) + + sql = f""" + INSERT INTO {self._table_name} (session_id, data, expires_at) + VALUES (%s, %s, %s) + ON CONFLICT (session_id) + DO UPDATE SET + data = EXCLUDED.data, + expires_at = EXCLUDED.expires_at, + updated_at = CURRENT_TIMESTAMP + """ + + with self._config.provide_connection() as conn: + conn.execute(sql.encode(), (key, data, expires_at)) + conn.commit() + + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data. + expires_in: Time until expiration. + """ + await async_(self._set)(key, value, expires_in) + + def _delete(self, key: str) -> None: + """Synchronous implementation of delete.""" + sql = f"DELETE FROM {self._table_name} WHERE session_id = %s" + + with self._config.provide_connection() as conn: + conn.execute(sql.encode(), (key,)) + conn.commit() + + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. + """ + await async_(self._delete)(key) + + def _delete_all(self) -> None: + """Synchronous implementation of delete_all.""" + sql = f"DELETE FROM {self._table_name}" + + with self._config.provide_connection() as conn: + conn.execute(sql.encode()) + conn.commit() + logger.debug("Deleted all sessions from table: %s", self._table_name) + + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + await async_(self._delete_all)() + + def _exists(self, key: str) -> bool: + """Synchronous implementation of exists.""" + sql = f""" + SELECT 1 FROM {self._table_name} + WHERE session_id = %s + AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP) + """ + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(sql.encode(), (key,)) + result = cur.fetchone() + return result is not None + + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. 
+ """ + return await async_(self._exists)(key) + + def _expires_in(self, key: str) -> "int | None": + """Synchronous implementation of expires_in.""" + sql = f""" + SELECT expires_at FROM {self._table_name} + WHERE session_id = %s + """ + + with self._config.provide_connection() as conn: + with conn.cursor() as cur: + cur.execute(sql.encode(), (key,)) + row = cur.fetchone() + + if row is None or row["expires_at"] is None: + return None + + expires_at = row["expires_at"] + now = datetime.now(timezone.utc) + + if expires_at <= now: + return 0 + + delta = expires_at - now + return int(delta.total_seconds()) + + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. + """ + return await async_(self._expires_in)(key) + + def _delete_expired(self) -> int: + """Synchronous implementation of delete_expired.""" + sql = f"DELETE FROM {self._table_name} WHERE expires_at <= CURRENT_TIMESTAMP" + + with self._config.provide_connection() as conn, conn.cursor() as cur: + cur.execute(sql.encode()) + conn.commit() + count = cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + if count > 0: + logger.debug("Cleaned up %d expired sessions", count) + return count + + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. + """ + return await async_(self._delete_expired)() diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index 524df727..470268a3 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -51,6 +51,7 @@ def __init__( statement_config: "StatementConfig | None" = None, driver_features: "dict[str, Any] | None" = None, bind_key: "str | None" = None, + extension_config: "dict[str, dict[str, Any]] | None" = None, ) -> None: """Initialize SQLite configuration. 
@@ -61,6 +62,7 @@ def __init__(
             statement_config: Default SQL statement configuration
             driver_features: Optional driver feature configuration
             bind_key: Optional bind key for the configuration
+            extension_config: Extension-specific configuration (e.g., Litestar plugin settings)
         """
         if pool_config is None:
             pool_config = {}
@@ -84,6 +86,7 @@ def __init__(
             migration_config=migration_config,
             statement_config=statement_config or sqlite_statement_config,
             driver_features=driver_features or {},
+            extension_config=extension_config,
         )
 
     def _get_connection_config_dict(self) -> "dict[str, Any]":
diff --git a/sqlspec/adapters/sqlite/litestar/__init__.py b/sqlspec/adapters/sqlite/litestar/__init__.py
new file mode 100644
index 00000000..a2e930f4
--- /dev/null
+++ b/sqlspec/adapters/sqlite/litestar/__init__.py
@@ -0,0 +1,5 @@
+"""Litestar integration for SQLite adapter."""
+
+from sqlspec.adapters.sqlite.litestar.store import SQLiteStore
+
+__all__ = ("SQLiteStore",)
diff --git a/sqlspec/adapters/sqlite/litestar/store.py b/sqlspec/adapters/sqlite/litestar/store.py
new file mode 100644
index 00000000..c146e911
--- /dev/null
+++ b/sqlspec/adapters/sqlite/litestar/store.py
@@ -0,0 +1,317 @@
+"""SQLite sync session store for Litestar integration."""
+
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING
+
+from sqlspec.extensions.litestar.store import BaseSQLSpecStore
+from sqlspec.utils.logging import get_logger
+from sqlspec.utils.sync_tools import async_
+
+if TYPE_CHECKING:
+    from sqlspec.adapters.sqlite.config import SqliteConfig
+
+logger = get_logger("adapters.sqlite.litestar.store")
+
+SECONDS_PER_DAY = 86400.0
+JULIAN_EPOCH = 2440587.5
+
+__all__ = ("SQLiteStore",)
+
+
+class SQLiteStore(BaseSQLSpecStore["SqliteConfig"]):
+    """SQLite session store using synchronous SQLite driver.
+
+    Implements server-side session storage for Litestar using SQLite
+    via the synchronous sqlite3 driver. Uses Litestar's sync_to_thread
+    utility to provide an async interface compatible with the Store protocol.
+
+    Provides efficient session management with:
+    - Sync operations wrapped for async compatibility
+    - INSERT OR REPLACE for UPSERT functionality
+    - Automatic expiration handling
+    - Efficient cleanup of expired sessions
+
+    Args:
+        config: SqliteConfig instance.
+        table_name: Name of the session table. Defaults to "litestar_session".
+
+    Example:
+        from sqlspec.adapters.sqlite import SqliteConfig
+        from sqlspec.adapters.sqlite.litestar.store import SQLiteStore
+
+        config = SqliteConfig(pool_config={"database": ":memory:"})
+        store = SQLiteStore(config)
+        await store.create_table()
+    """
+
+    __slots__ = ()
+
+    def __init__(self, config: "SqliteConfig", table_name: str = "litestar_session") -> None:
+        """Initialize SQLite session store.
+
+        Args:
+            config: SqliteConfig instance.
+            table_name: Name of the session table.
+        """
+        super().__init__(config, table_name)
+
+    def _get_create_table_sql(self) -> str:
+        """Get SQLite CREATE TABLE SQL.
+
+        Returns:
+            SQL statement to create the sessions table with proper indexes.
+
+        Notes:
+            - Uses REAL type for expires_at (stores Julian Day number)
+            - Julian Day enables direct comparison with julianday('now')
+            - Partial index WHERE expires_at IS NOT NULL reduces index size
+            - Queries compare the bare expires_at column (not julianday(expires_at))
+              so the index is actually used by the query optimizer
+        """
+        return f"""
+        CREATE TABLE IF NOT EXISTS {self._table_name} (
+            session_id TEXT PRIMARY KEY,
+            data BLOB NOT NULL,
+            expires_at REAL
+        );
+        CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at
+        ON {self._table_name}(expires_at) WHERE expires_at IS NOT NULL;
+        """
+
+    def _get_drop_table_sql(self) -> "list[str]":
+        """Get SQLite DROP TABLE SQL statements.
+
+        Returns:
+            List of SQL statements to drop indexes and table.
+        """
+        return [f"DROP INDEX IF EXISTS idx_{self._table_name}_expires_at", f"DROP TABLE IF EXISTS {self._table_name}"]
+
+    def _datetime_to_julian(self, dt: "datetime | None") -> "float | None":
+        """Convert datetime to Julian Day number for SQLite storage.
+
+        Args:
+            dt: Datetime to convert (must be UTC-aware).
+
+        Returns:
+            Julian Day number as REAL, or None if dt is None.
+
+        Notes:
+            Julian Day number is days since November 24, 4714 BCE (proleptic Gregorian).
+            This enables direct comparison with julianday('now') in SQL queries.
+        """
+        if dt is None:
+            return None
+
+        epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
+        delta_days = (dt - epoch).total_seconds() / SECONDS_PER_DAY
+        return JULIAN_EPOCH + delta_days
+
+    def _julian_to_datetime(self, julian: "float | None") -> "datetime | None":
+        """Convert Julian Day number back to datetime.
+
+        Args:
+            julian: Julian Day number.
+
+        Returns:
+            UTC-aware datetime, or None if julian is None.
+        """
+        if julian is None:
+            return None
+
+        days_since_epoch = julian - JULIAN_EPOCH
+        timestamp = days_since_epoch * SECONDS_PER_DAY
+        return datetime.fromtimestamp(timestamp, tz=timezone.utc)
+
+    def _create_table(self) -> None:
+        """Synchronous implementation of create_table."""
+        sql = self._get_create_table_sql()
+        with self._config.provide_connection() as conn:
+            conn.executescript(sql)
+        logger.debug("Created session table: %s", self._table_name)
+
+    async def create_table(self) -> None:
+        """Create the session table if it doesn't exist."""
+        await async_(self._create_table)()
+
+    def _get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
+        """Synchronous implementation of get."""
+        sql = f"""
+        SELECT data, expires_at FROM {self._table_name}
+        WHERE session_id = ?
+        AND (expires_at IS NULL OR expires_at > julianday('now'))
+        """
+
+        with self._config.provide_connection() as conn:
+            cursor = conn.execute(sql, (key,))
+            row = cursor.fetchone()
+
+            if row is None:
+                return None
+
+            data, expires_at_julian = row
+
+            if renew_for is not None and expires_at_julian is not None:
+                new_expires_at = self._calculate_expires_at(renew_for)
+                new_expires_at_julian = self._datetime_to_julian(new_expires_at)
+                if new_expires_at_julian is not None:
+                    update_sql = f"""
+                    UPDATE {self._table_name}
+                    SET expires_at = ?
+                    WHERE session_id = ?
+                    """
+                    conn.execute(update_sql, (new_expires_at_julian, key))
+                    conn.commit()
+
+            return bytes(data)
+
+    async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
+        """Get a session value by key.
+
+        Args:
+            key: Session ID to retrieve.
+            renew_for: If given, renew the expiry time for this duration.
+
+        Returns:
+            Session data as bytes if found and not expired, None otherwise.
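+
+        Notes:
+            expires_at is stored as a Julian Day REAL: for example,
+            2024-01-01T00:00:00Z is 19723 days after the Unix epoch, so it is
+            stored as 2440587.5 + 19723 = 2460310.5.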
+        """
+        return await async_(self._get)(key, renew_for)
+
+    def _set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None:
+        """Synchronous implementation of set.
+
+        Notes:
+            Stores expires_at as Julian Day number (REAL) for optimal index usage.
+        """
+        data = self._value_to_bytes(value)
+        expires_at = self._calculate_expires_at(expires_in)
+        expires_at_julian = self._datetime_to_julian(expires_at)
+
+        sql = f"""
+        INSERT OR REPLACE INTO {self._table_name} (session_id, data, expires_at)
+        VALUES (?, ?, ?)
+        """
+
+        with self._config.provide_connection() as conn:
+            conn.execute(sql, (key, data, expires_at_julian))
+            conn.commit()
+
+    async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None:
+        """Store a session value.
+
+        Args:
+            key: Session ID.
+            value: Session data.
+            expires_in: Time until expiration.
+        """
+        await async_(self._set)(key, value, expires_in)
+
+    def _delete(self, key: str) -> None:
+        """Synchronous implementation of delete."""
+        sql = f"DELETE FROM {self._table_name} WHERE session_id = ?"
+
+        with self._config.provide_connection() as conn:
+            conn.execute(sql, (key,))
+            conn.commit()
+
+    async def delete(self, key: str) -> None:
+        """Delete a session by key.
+
+        Args:
+            key: Session ID to delete.
+        """
+        await async_(self._delete)(key)
+
+    def _delete_all(self) -> None:
+        """Synchronous implementation of delete_all."""
+        sql = f"DELETE FROM {self._table_name}"
+
+        with self._config.provide_connection() as conn:
+            conn.execute(sql)
+            conn.commit()
+        logger.debug("Deleted all sessions from table: %s", self._table_name)
+
+    async def delete_all(self) -> None:
+        """Delete all sessions from the store."""
+        await async_(self._delete_all)()
+
+    def _exists(self, key: str) -> bool:
+        """Synchronous implementation of exists."""
+        sql = f"""
+        SELECT 1 FROM {self._table_name}
+        WHERE session_id = ?
+        AND (expires_at IS NULL OR expires_at > julianday('now'))
+        """
+
+        with self._config.provide_connection() as conn:
+            cursor = conn.execute(sql, (key,))
+            result = cursor.fetchone()
+            return result is not None
+
+    async def exists(self, key: str) -> bool:
+        """Check if a session key exists and is not expired.
+
+        Args:
+            key: Session ID to check.
+
+        Returns:
+            True if the session exists and is not expired.
+        """
+        return await async_(self._exists)(key)
+
+    def _expires_in(self, key: str) -> "int | None":
+        """Synchronous implementation of expires_in."""
+        sql = f"""
+        SELECT expires_at FROM {self._table_name}
+        WHERE session_id = ?
+        """
+
+        with self._config.provide_connection() as conn:
+            cursor = conn.execute(sql, (key,))
+            row = cursor.fetchone()
+
+            if row is None or row[0] is None:
+                return None
+
+            expires_at_julian = row[0]
+            expires_at = self._julian_to_datetime(expires_at_julian)
+
+            if expires_at is None:
+                return None
+
+            now = datetime.now(timezone.utc)
+
+            if expires_at <= now:
+                return 0
+
+            delta = expires_at - now
+            return int(delta.total_seconds())
+
+    async def expires_in(self, key: str) -> "int | None":
+        """Get the time in seconds until the session expires.
+
+        Args:
+            key: Session ID to check.
+
+        Returns:
+            Seconds until expiration, or None if no expiry or key doesn't exist.
+        """
+        return await async_(self._expires_in)(key)
+
+    def _delete_expired(self) -> int:
+        """Synchronous implementation of delete_expired."""
+        sql = f"DELETE FROM {self._table_name} WHERE expires_at <= julianday('now')"
+
+        with self._config.provide_connection() as conn:
+            cursor = conn.execute(sql)
+            conn.commit()
+            count = cursor.rowcount if cursor.rowcount > 0 else 0
+            if count > 0:
+                logger.debug("Cleaned up %d expired sessions", count)
+            return count
+
+    async def delete_expired(self) -> int:
+        """Delete all expired sessions.
+
+        Returns:
+            Number of sessions deleted.
+        """
+        return await async_(self._delete_expired)()
diff --git a/sqlspec/adapters/sqlite/pool.py b/sqlspec/adapters/sqlite/pool.py
index d6e1154a..310670a2 100644
--- a/sqlspec/adapters/sqlite/pool.py
+++ b/sqlspec/adapters/sqlite/pool.py
@@ -62,7 +62,7 @@ def _create_connection(self) -> SqliteConnection:
         if self._enable_optimizations:
             database = self._connection_parameters.get("database", ":memory:")
-            is_memory = database == ":memory:" or database.startswith("file::memory:")
+            is_memory = database == ":memory:" or "mode=memory" in database or database.startswith("file::memory:")
 
             if not is_memory:
                 connection.execute("PRAGMA journal_mode = DELETE")
diff --git a/sqlspec/extensions/litestar/__init__.py b/sqlspec/extensions/litestar/__init__.py
index 2d60c576..54fdaaa9 100644
--- a/sqlspec/extensions/litestar/__init__.py
+++ b/sqlspec/extensions/litestar/__init__.py
@@ -7,12 +7,14 @@
     CommitMode,
     SQLSpecPlugin,
 )
+from sqlspec.extensions.litestar.store import BaseSQLSpecStore
 
 __all__ = (
     "DEFAULT_COMMIT_MODE",
     "DEFAULT_CONNECTION_KEY",
     "DEFAULT_POOL_KEY",
     "DEFAULT_SESSION_KEY",
+    "BaseSQLSpecStore",
     "CommitMode",
     "SQLSpecPlugin",
     "database_group",
diff --git a/sqlspec/extensions/litestar/cli.py b/sqlspec/extensions/litestar/cli.py
index a57295f7..3b6f8f11 100644
--- a/sqlspec/extensions/litestar/cli.py
+++ b/sqlspec/extensions/litestar/cli.py
@@ -46,3 +46,51 @@ def database_group(ctx: "click.Context") -> None:
 
 
 add_migration_commands(database_group)
+
+
+def add_sessions_delete_expired_command() -> None:
+    """Add delete-expired command to Litestar's sessions CLI group."""
+    try:
+        from litestar.cli._utils import console
+        from litestar.cli.commands.sessions import get_session_backend, sessions_group
+    except ImportError:
+        return
+
+    @sessions_group.command("delete-expired")  # type: ignore[misc]
+    @click.option(
+        "--verbose", is_flag=True, default=False, help="Show detailed information about the cleanup operation"
+    )
+    def delete_expired_sessions_command(app: "Litestar", verbose: bool) -> None:
+        """Delete expired sessions from the session store.
+
+        This command removes all sessions that have passed their expiration time.
+        It can be scheduled via cron or systemd timers for automatic maintenance.
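+
+        For example, a daily crontab entry (illustrative schedule):
+
+            0 3 * * * litestar sessions delete-expired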
+ + Examples: + litestar sessions delete-expired + litestar sessions delete-expired --verbose + """ + import anyio + + backend = get_session_backend(app) + store = backend.config.get_store_from_app(app) + + if not hasattr(store, "delete_expired"): + console.print(f"[red]{type(store).__name__} does not support deleting expired sessions") + return + + async def _delete_expired() -> int: + return await store.delete_expired() # type: ignore[no-any-return] + + count = anyio.run(_delete_expired) + + if count > 0: + if verbose: + console.print(f"[green]Successfully deleted {count} expired session(s)") + else: + console.print(f"[green]Deleted {count} expired session(s)") + else: + console.print("[yellow]No expired sessions found") + + +add_sessions_delete_expired_command() diff --git a/sqlspec/extensions/litestar/migrations/0001_create_session_table.py b/sqlspec/extensions/litestar/migrations/0001_create_session_table.py new file mode 100644 index 00000000..05772c7c --- /dev/null +++ b/sqlspec/extensions/litestar/migrations/0001_create_session_table.py @@ -0,0 +1,148 @@ +"""Create Litestar session table migration using store DDL definitions.""" + +from typing import TYPE_CHECKING, NoReturn + +from sqlspec.exceptions import SQLSpecError +from sqlspec.utils.logging import get_logger +from sqlspec.utils.module_loader import import_string + +if TYPE_CHECKING: + from sqlspec.extensions.litestar.store import BaseSQLSpecStore + from sqlspec.migrations.context import MigrationContext + +logger = get_logger("migrations.litestar.session") + +__all__ = ("down", "up") + + +def _get_store_class(context: "MigrationContext | None") -> "type[BaseSQLSpecStore]": + """Get the appropriate store class based on the config's module path. + + Args: + context: Migration context containing config. + + Returns: + Store class matching the config's adapter. + + Notes: + Dynamically imports the store class from the config's module path. + For example, AsyncpgConfig at 'sqlspec.adapters.asyncpg.config' + maps to AsyncpgStore at 'sqlspec.adapters.asyncpg.litestar.store.AsyncpgStore'. + """ + if not context or not context.config: + _raise_missing_config() + + config_class = type(context.config) + config_module = config_class.__module__ + config_name = config_class.__name__ + + if not config_module.startswith("sqlspec.adapters."): + _raise_unsupported_config(f"{config_module}.{config_name}") + + adapter_name = config_module.split(".")[2] + store_class_name = config_name.replace("Config", "Store") + + store_path = f"sqlspec.adapters.{adapter_name}.litestar.store.{store_class_name}" + + try: + store_class: type[BaseSQLSpecStore] = import_string(store_path) + except ImportError as e: + _raise_store_import_failed(store_path, e) + + return store_class + + +def _raise_missing_config() -> NoReturn: + """Raise error when migration context has no config. + + Raises: + SQLSpecError: Always raised. + """ + msg = "Migration context must have a config to determine store class" + raise SQLSpecError(msg) + + +def _raise_unsupported_config(config_type: str) -> NoReturn: + """Raise error for unsupported config type. + + Args: + config_type: The unsupported config type name. + + Raises: + SQLSpecError: Always raised with config type info. + """ + msg = f"Unsupported config type for Litestar session migration: {config_type}" + raise SQLSpecError(msg) + + +def _raise_store_import_failed(store_path: str, error: ImportError) -> NoReturn: + """Raise error when store class import fails. + + Args: + store_path: The import path that failed. 
+ error: The original import error. + + Raises: + SQLSpecError: Always raised with import details. + """ + msg = f"Failed to import Litestar store class from {store_path}: {error}" + raise SQLSpecError(msg) from error + + +def _get_table_name(context: "MigrationContext | None") -> str: + """Extract table name from migration context. + + Args: + context: Migration context with extension config. + + Returns: + Table name for the session store. + """ + if context and context.extension_config: + table_name: str = context.extension_config.get("session_table", "litestar_session") + return table_name + return "litestar_session" + + +async def up(context: "MigrationContext | None" = None) -> "list[str]": + """Create the litestar session table using store DDL definitions. + + This migration delegates to the appropriate store class to generate + dialect-specific DDL. The store classes contain the single source of + truth for session table schemas. + + Args: + context: Migration context containing config. + + Returns: + List of SQL statements to execute for upgrade. + """ + table_name = _get_table_name(context) + store_class = _get_store_class(context) + if context is None or context.config is None: + _raise_missing_config() + store = store_class(config=context.config, table_name=table_name) + + return [store._get_create_table_sql()] # pyright: ignore[reportPrivateUsage] + + +async def down(context: "MigrationContext | None" = None) -> "list[str]": + """Drop the litestar session table using store DDL definitions. + + This migration delegates to the appropriate store class to generate + dialect-specific DROP statements. The store classes contain the single + source of truth for session table schemas. + + Args: + context: Migration context containing config. + + Returns: + List of SQL statements to execute for downgrade. + """ + table_name = _get_table_name(context) + store_class = _get_store_class(context) + if context is None or context.config is None: + _raise_missing_config() + store = store_class(config=context.config, table_name=table_name) + + return store._get_drop_table_sql() # pyright: ignore[reportPrivateUsage] diff --git a/sqlspec/extensions/litestar/migrations/__init__.py b/sqlspec/extensions/litestar/migrations/__init__.py new file mode 100644 index 00000000..79167239 --- /dev/null +++ b/sqlspec/extensions/litestar/migrations/__init__.py @@ -0,0 +1,3 @@ +"""Litestar extension migrations for session table creation.""" + +__all__ = () diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index fec875ec..2f031f2b 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -80,7 +80,38 @@ class _PluginConfigState: class SQLSpecPlugin(InitPluginProtocol, CLIPlugin): - """Litestar plugin for SQLSpec database integration.""" + """Litestar plugin for SQLSpec database integration. + + Session Table Migrations: + The Litestar extension includes migrations for creating session storage tables. + To include these migrations in your database migration workflow, add 'litestar' + to the include_extensions list in your migration configuration. 
+
+    Example:
+        config = AsyncpgConfig(
+            pool_config={"dsn": "postgresql://localhost/db"},
+            extension_config={
+                "litestar": {
+                    "connection_key": "db_connection",
+                    "commit_mode": "autocommit"
+                }
+            },
+            migration_config={
+                "script_location": "migrations",
+                "include_extensions": ["litestar"],
+            }
+        )
+
+    The session table migration will automatically use the appropriate data column
+    type for your database dialect (e.g., BYTEA for PostgreSQL, BLOB for MySQL and
+    SQLite). Customize the table name via extension_config:
+
+        extension_config={
+            "litestar": {"session_table": "custom_sessions"}
+        }
+    """
 
     __slots__ = ("_plugin_configs", "_sqlspec")
 
@@ -270,7 +301,7 @@ def get_annotation(
             The annotation for the configuration.
         """
         for state in self._plugin_configs:
-            if key in (state.config, state.annotation) or key in {state.connection_key, state.pool_key}:
+            if key in {state.config, state.annotation} or key in {state.connection_key, state.pool_key}:
                 return cast(
                     "type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]",
                     state.annotation,
@@ -314,7 +345,7 @@ def get_config(
             return cast("DatabaseConfigProtocol[Any, Any, Any]", state.config)  # type: ignore[redundant-cast]
 
         for state in self._plugin_configs:
-            if name in (state.config, state.annotation):
+            if name in {state.config, state.annotation}:
                 return cast("DatabaseConfigProtocol[Any, Any, Any]", state.config)  # type: ignore[redundant-cast]
 
         msg = f"No database configuration found for name '{name}'. Available keys: {self._get_available_keys()}"
@@ -411,7 +442,7 @@ def _get_plugin_state(
             return state
 
         for state in self._plugin_configs:
-            if key in (state.config, state.annotation):
+            if key in {state.config, state.annotation}:
                 return state
 
         self._raise_config_not_found(key)
diff --git a/sqlspec/extensions/litestar/store.py b/sqlspec/extensions/litestar/store.py
new file mode 100644
index 00000000..057e52a3
--- /dev/null
+++ b/sqlspec/extensions/litestar/store.py
@@ -0,0 +1,244 @@
+"""Base session store classes for Litestar integration."""
+
+import re
+from abc import ABC, abstractmethod
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, Final, Generic, TypeVar
+
+from sqlspec.utils.logging import get_logger
+
+if TYPE_CHECKING:
+    from types import TracebackType
+
+
+ConfigT = TypeVar("ConfigT")
+
+
+logger = get_logger("extensions.litestar.store")
+
+__all__ = ("BaseSQLSpecStore",)
+
+VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
+MAX_TABLE_NAME_LENGTH: Final = 63
+
+
+class BaseSQLSpecStore(ABC, Generic[ConfigT]):
+    """Base class for SQLSpec-backed Litestar session stores.
+
+    Implements the litestar.stores.base.Store protocol for server-side session
+    storage using SQLSpec database adapters.
+
+    This abstract base class provides common functionality for all database-specific
+    store implementations including:
+    - Connection management via SQLSpec configs
+    - Session expiration calculation
+    - Table creation utilities
+
+    Subclasses must implement dialect-specific SQL queries.
+
+    Args:
+        config: SQLSpec database configuration (async or sync).
+        table_name: Name of the session table. Defaults to "litestar_session".
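+
+    Raises:
+        ValueError: If table_name is not a valid SQL identifier (validated
+            against VALID_TABLE_NAME_PATTERN at init time).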
+ + Example: + from sqlspec.adapters.asyncpg import AsyncpgConfig + from sqlspec.adapters.asyncpg.litestar.store import AsyncpgStore + + config = AsyncpgConfig(pool_config={"dsn": "postgresql://..."}) + store = AsyncpgStore(config) + await store.create_table() + """ + + __slots__ = ("_config", "_table_name") + + def __init__(self, config: ConfigT, table_name: str = "litestar_session") -> None: + """Initialize the session store. + + Args: + config: SQLSpec database configuration. + table_name: Name of the session table. + """ + self._validate_table_name(table_name) + self._config = config + self._table_name = table_name + + @property + def config(self) -> ConfigT: + """Return the database configuration.""" + return self._config + + @property + def table_name(self) -> str: + """Return the session table name.""" + return self._table_name + + @abstractmethod + async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None": + """Get a session value by key. + + Args: + key: Session ID to retrieve. + renew_for: If given and the value had an initial expiry time set, renew the + expiry time for ``renew_for`` seconds. If the value has not been set + with an expiry time this is a no-op. + + Returns: + Session data as bytes if found and not expired, None otherwise. + """ + raise NotImplementedError + + @abstractmethod + async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None: + """Store a session value. + + Args: + key: Session ID. + value: Session data (will be converted to bytes if string). + expires_in: Time in seconds or timedelta before expiration. + """ + raise NotImplementedError + + @abstractmethod + async def delete(self, key: str) -> None: + """Delete a session by key. + + Args: + key: Session ID to delete. + """ + raise NotImplementedError + + @abstractmethod + async def delete_all(self) -> None: + """Delete all sessions from the store.""" + raise NotImplementedError + + @abstractmethod + async def exists(self, key: str) -> bool: + """Check if a session key exists and is not expired. + + Args: + key: Session ID to check. + + Returns: + True if the session exists and is not expired. + """ + raise NotImplementedError + + @abstractmethod + async def expires_in(self, key: str) -> "int | None": + """Get the time in seconds until the session expires. + + Args: + key: Session ID to check. + + Returns: + Seconds until expiration, or None if no expiry or key doesn't exist. + """ + raise NotImplementedError + + @abstractmethod + async def delete_expired(self) -> int: + """Delete all expired sessions. + + Returns: + Number of sessions deleted. + """ + raise NotImplementedError + + @abstractmethod + async def create_table(self) -> None: + """Create the session table if it doesn't exist.""" + raise NotImplementedError + + @abstractmethod + def _get_create_table_sql(self) -> str: + """Get the CREATE TABLE SQL for this database dialect. + + Returns: + SQL statement to create the sessions table. + """ + raise NotImplementedError + + @abstractmethod + def _get_drop_table_sql(self) -> "list[str]": + """Get the DROP TABLE SQL statements for this database dialect. + + Returns: + List of SQL statements to drop the table and all indexes. + Order matters: drop indexes before table. + + Notes: + Should use IF EXISTS or dialect-specific error handling + to allow idempotent migrations. 
+        """
+        raise NotImplementedError
+
+    async def __aenter__(self) -> "BaseSQLSpecStore":
+        """Enter context manager."""
+        return self
+
+    async def __aexit__(
+        self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None"
+    ) -> None:
+        """Exit context manager."""
+        return
+
+    def _calculate_expires_at(self, expires_in: "int | timedelta | None") -> "datetime | None":
+        """Calculate expiration timestamp from expires_in.
+
+        Args:
+            expires_in: Seconds or timedelta until expiration.
+
+        Returns:
+            UTC datetime of expiration, or None if no expiration.
+        """
+        if expires_in is None:
+            return None
+
+        expires_in_seconds = int(expires_in.total_seconds()) if isinstance(expires_in, timedelta) else expires_in
+
+        if expires_in_seconds <= 0:
+            return None
+
+        return datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds)
+
+    def _value_to_bytes(self, value: "str | bytes") -> bytes:
+        """Convert value to bytes if needed.
+
+        Args:
+            value: String or bytes value.
+
+        Returns:
+            Value as bytes.
+        """
+        if isinstance(value, str):
+            return value.encode("utf-8")
+        return value
+
+    @staticmethod
+    def _validate_table_name(table_name: str) -> None:
+        """Validate table name for SQL safety.
+
+        Args:
+            table_name: Table name to validate.
+
+        Raises:
+            ValueError: If table name is invalid.
+
+        Notes:
+            - Must start with letter or underscore
+            - Can only contain letters, numbers, and underscores
+            - Maximum length is 63 characters (PostgreSQL limit)
+            - Prevents SQL injection in table names
+        """
+        if not table_name:
+            msg = "Table name cannot be empty"
+            raise ValueError(msg)
+
+        if len(table_name) > MAX_TABLE_NAME_LENGTH:
+            msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})"
+            raise ValueError(msg)
+
+        if not VALID_TABLE_NAME_PATTERN.match(table_name):
+            msg = f"Invalid table name: {table_name!r}. Must start with letter/underscore and contain only alphanumeric characters and underscores"
+            raise ValueError(msg)
diff --git a/sqlspec/typing.py b/sqlspec/typing.py
index 0c0cecf0..afe36902 100644
--- a/sqlspec/typing.py
+++ b/sqlspec/typing.py
@@ -1,7 +1,7 @@
 # pyright: ignore[reportAttributeAccessIssue]
 from collections.abc import Iterator, Mapping
 from functools import lru_cache
-from typing import TYPE_CHECKING, Annotated, Any, Protocol, TypeAlias
+from typing import TYPE_CHECKING, Annotated, Any, Protocol, TypeAlias, _TypedDict  # pyright: ignore
 
 from typing_extensions import TypeVar
 
@@ -105,7 +105,9 @@ def __len__(self) -> int: ...
 TupleRow: TypeAlias = "tuple[Any, ...]"
 """Type variable for TupleRow types."""
 
-SupportedSchemaModel: TypeAlias = "DictLike | StructStub | BaseModelStub | DataclassProtocol | AttrsInstanceStub"
+SupportedSchemaModel: TypeAlias = (
+    DictLike | StructStub | BaseModelStub | DataclassProtocol | AttrsInstanceStub | _TypedDict
+)
 """Type alias for pydantic or msgspec models.
 
 :class:`msgspec.Struct` | :class:`pydantic.BaseModel` | :class:`DataclassProtocol` | :class:`AttrsInstance`
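A minimal usage sketch of the new stores with Litestar's server-side sessions, assuming Litestar's documented `stores` registry and `ServerSideSessionConfig`; it uses the AiosqliteStore added in this PR, and names such as "sessions.db" are illustrative:

    from litestar import Litestar
    from litestar.middleware.session.server_side import ServerSideSessionConfig

    from sqlspec.adapters.aiosqlite.config import AiosqliteConfig
    from sqlspec.adapters.aiosqlite.litestar.store import AiosqliteStore

    # Configure the database and the session store (table defaults to "litestar_session").
    db_config = AiosqliteConfig(pool_config={"database": "sessions.db"})
    store = AiosqliteStore(db_config)

    # The session table must exist before first use: run the bundled "litestar"
    # extension migration, or call `await store.create_table()` at startup.
    app = Litestar(
        route_handlers=[],
        stores={"sessions": store},
        middleware=[ServerSideSessionConfig(store="sessions").middleware],
    )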
diff --git a/sqlspec/utils/sync_tools.py b/sqlspec/utils/sync_tools.py index cf5e5cf6..19e91776 100644 --- a/sqlspec/utils/sync_tools.py +++ b/sqlspec/utils/sync_tools.py @@ -8,6 +8,7 @@ import asyncio import functools import inspect +import os import sys from contextlib import AbstractAsyncContextManager, AbstractContextManager from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast @@ -47,12 +48,20 @@ def __init__(self, total_tokens: int) -> None: """ self._total_tokens = total_tokens self._semaphore_instance: asyncio.Semaphore | None = None + self._pid: int | None = None @property def _semaphore(self) -> asyncio.Semaphore: - """Lazy initialization of asyncio.Semaphore for Python 3.9 compatibility.""" - if self._semaphore_instance is None: + """Lazy initialization of asyncio.Semaphore with per-process tracking. + + Reinitializes the semaphore if running in a new process (detected via PID). + This ensures pytest-xdist workers each get their own semaphore bound to + their event loop, preventing cross-process deadlocks. + """ + current_pid = os.getpid() + if self._semaphore_instance is None or self._pid != current_pid: self._semaphore_instance = asyncio.Semaphore(self._total_tokens) + self._pid = current_pid return self._semaphore_instance async def acquire(self) -> None: @@ -72,6 +81,7 @@ def total_tokens(self) -> int: def total_tokens(self, value: int) -> None: self._total_tokens = value self._semaphore_instance = None + self._pid = None async def __aenter__(self) -> None: """Async context manager entry.""" @@ -84,7 +94,7 @@ async def __aexit__( self.release() -_default_limiter = CapacityLimiter(15) +_default_limiter = CapacityLimiter(1000) def run_(async_function: "Callable[ParamSpecT, Coroutine[Any, Any, ReturnT]]") -> "Callable[ParamSpecT, ReturnT]": diff --git a/tests/conftest.py b/tests/conftest.py index 714061af..f8731e91 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,6 +29,11 @@ def pytest_addoption(parser: pytest.Parser) -> None: @pytest.fixture def anyio_backend() -> str: + """Configure AnyIO to use asyncio backend only. + + Disables trio backend to prevent duplicate test runs and compatibility issues + with pytest-xdist parallel execution. 
+ """ return "asyncio" diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 99aba378..7ce9d731 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -2,21 +2,11 @@ from __future__ import annotations -import asyncio -from collections.abc import Generator from typing import Any import pytest -@pytest.fixture(scope="session") -def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]: - """Create an event loop for async tests.""" - loop = asyncio.new_event_loop() - yield loop - loop.close() - - @pytest.fixture def sample_data() -> list[tuple[str, int]]: """Standard sample data for testing across adapters.""" diff --git a/tests/integration/test_adapters/test_adbc/test_extensions/__init__.py b/tests/integration/test_adapters/test_adbc/test_extensions/__init__.py new file mode 100644 index 00000000..3e9b620b --- /dev/null +++ b/tests/integration/test_adapters/test_adbc/test_extensions/__init__.py @@ -0,0 +1 @@ +"""ADBC extension integration tests.""" diff --git a/tests/integration/test_adapters/test_adbc/test_extensions/test_litestar/__init__.py b/tests/integration/test_adapters/test_adbc/test_extensions/test_litestar/__init__.py new file mode 100644 index 00000000..8f1fbc6e --- /dev/null +++ b/tests/integration/test_adapters/test_adbc/test_extensions/test_litestar/__init__.py @@ -0,0 +1 @@ +"""ADBC Litestar integration tests.""" diff --git a/tests/integration/test_adapters/test_adbc/test_extensions/test_litestar/test_store.py b/tests/integration/test_adapters/test_adbc/test_extensions/test_litestar/test_store.py new file mode 100644 index 00000000..3e2bb5ac --- /dev/null +++ b/tests/integration/test_adapters/test_adbc/test_extensions/test_litestar/test_store.py @@ -0,0 +1,261 @@ +"""Integration tests for ADBC session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest +from pytest_databases.docker.postgres import PostgresService + +from sqlspec.adapters.adbc.config import AdbcConfig +from sqlspec.adapters.adbc.litestar.store import ADBCStore + +pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.adbc, pytest.mark.integration] + + +@pytest.fixture +async def adbc_store(postgres_service: PostgresService) -> AsyncGenerator[ADBCStore, None]: + """Create ADBC store with PostgreSQL backend.""" + config = AdbcConfig( + connection_config={ + "uri": f"postgresql://{postgres_service.user}:{postgres_service.password}@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + } + ) + store = ADBCStore(config, table_name="test_adbc_sessions") + await store.create_table() + yield store + try: + await store.delete_all() + except Exception: + pass + + +async def test_store_create_table(adbc_store: ADBCStore) -> None: + """Test table creation.""" + assert adbc_store.table_name == "test_adbc_sessions" + + +async def test_store_set_and_get(adbc_store: ADBCStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await adbc_store.set("session_123", test_data) + + result = await adbc_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(adbc_store: ADBCStore) -> None: + """Test getting a non-existent session returns None.""" + result = await adbc_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(adbc_store: ADBCStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await 
adbc_store.set("session_str", "string data") + + result = await adbc_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(adbc_store: ADBCStore) -> None: + """Test delete operation.""" + await adbc_store.set("session_to_delete", b"data") + + assert await adbc_store.exists("session_to_delete") + + await adbc_store.delete("session_to_delete") + + assert not await adbc_store.exists("session_to_delete") + assert await adbc_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(adbc_store: ADBCStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await adbc_store.delete("nonexistent") + + +async def test_store_expiration_with_int(adbc_store: ADBCStore) -> None: + """Test session expiration with integer seconds.""" + await adbc_store.set("expiring_session", b"data", expires_in=1) + + assert await adbc_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await adbc_store.get("expiring_session") + assert result is None + assert not await adbc_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(adbc_store: ADBCStore) -> None: + """Test session expiration with timedelta.""" + await adbc_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await adbc_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await adbc_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(adbc_store: ADBCStore) -> None: + """Test session without expiration persists.""" + await adbc_store.set("permanent_session", b"data") + + expires_in = await adbc_store.expires_in("permanent_session") + assert expires_in is None + + assert await adbc_store.exists("permanent_session") + + +async def test_store_expires_in(adbc_store: ADBCStore) -> None: + """Test expires_in returns correct time.""" + await adbc_store.set("timed_session", b"data", expires_in=10) + + expires_in = await adbc_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(adbc_store: ADBCStore) -> None: + """Test expires_in returns 0 for expired session.""" + await adbc_store.set("expired_session", b"data", expires_in=1) + + await asyncio.sleep(1.1) + + expires_in = await adbc_store.expires_in("expired_session") + assert expires_in == 0 + + +async def test_store_cleanup(adbc_store: ADBCStore) -> None: + """Test delete_expired removes only expired sessions.""" + await adbc_store.set("active_session", b"data", expires_in=60) + await adbc_store.set("expired_session_1", b"data", expires_in=1) + await adbc_store.set("expired_session_2", b"data", expires_in=1) + await adbc_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) + + count = await adbc_store.delete_expired() + assert count == 2 + + assert await adbc_store.exists("active_session") + assert await adbc_store.exists("permanent_session") + assert not await adbc_store.exists("expired_session_1") + assert not await adbc_store.exists("expired_session_2") + + +async def test_store_upsert(adbc_store: ADBCStore) -> None: + """Test updating existing session (UPSERT).""" + await adbc_store.set("session_upsert", b"original data") + + result = await adbc_store.get("session_upsert") + assert result == b"original data" + + await adbc_store.set("session_upsert", b"updated data") + + result = await adbc_store.get("session_upsert") + assert result == b"updated data" + + +async def 
test_store_upsert_with_expiration_change(adbc_store: ADBCStore) -> None: + """Test updating session expiration.""" + await adbc_store.set("session_exp", b"data", expires_in=60) + + expires_in = await adbc_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await adbc_store.set("session_exp", b"data", expires_in=10) + + expires_in = await adbc_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(adbc_store: ADBCStore) -> None: + """Test renewing session expiration on get.""" + await adbc_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await adbc_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await adbc_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await adbc_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(adbc_store: ADBCStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await adbc_store.set("large_session", large_data) + + result = await adbc_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(adbc_store: ADBCStore) -> None: + """Test delete_all removes all sessions.""" + await adbc_store.set("session1", b"data1") + await adbc_store.set("session2", b"data2") + await adbc_store.set("session3", b"data3") + + assert await adbc_store.exists("session1") + assert await adbc_store.exists("session2") + assert await adbc_store.exists("session3") + + await adbc_store.delete_all() + + assert not await adbc_store.exists("session1") + assert not await adbc_store.exists("session2") + assert not await adbc_store.exists("session3") + + +async def test_store_exists(adbc_store: ADBCStore) -> None: + """Test exists method.""" + assert not await adbc_store.exists("test_session") + + await adbc_store.set("test_session", b"data") + + assert await adbc_store.exists("test_session") + + +async def test_store_context_manager(adbc_store: ADBCStore) -> None: + """Test store can be used as async context manager.""" + async with adbc_store: + await adbc_store.set("ctx_session", b"data") + + result = await adbc_store.get("ctx_session") + assert result == b"data" + + +async def test_sync_to_thread_concurrency(adbc_store: ADBCStore) -> None: + """Test concurrent access via sync_to_thread wrapper. + + ADBC with PostgreSQL supports concurrent reads and writes. + We test concurrent writes followed by concurrent reads. 
+ """ + + async def write_session(session_id: int) -> None: + await adbc_store.set(f"session_{session_id}", f"data_{session_id}".encode()) + + await asyncio.gather(*[write_session(i) for i in range(10)]) + + async def read_session(session_id: int) -> "bytes | None": + return await adbc_store.get(f"session_{session_id}") + + results = await asyncio.gather(*[read_session(i) for i in range(10)]) + + for i, result in enumerate(results): + assert result == f"data_{i}".encode() diff --git a/tests/integration/test_adapters/test_aiosqlite/test_extensions/__init__.py b/tests/integration/test_adapters/test_aiosqlite/test_extensions/__init__.py new file mode 100644 index 00000000..a4aa669c --- /dev/null +++ b/tests/integration/test_adapters/test_aiosqlite/test_extensions/__init__.py @@ -0,0 +1 @@ +"""AioSQLite extension integration tests.""" diff --git a/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/__init__.py b/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/__init__.py new file mode 100644 index 00000000..d71e6124 --- /dev/null +++ b/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/__init__.py @@ -0,0 +1 @@ +"""Integration tests for AioSQLite Litestar extensions.""" diff --git a/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_store.py b/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_store.py new file mode 100644 index 00000000..981b76d7 --- /dev/null +++ b/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_store.py @@ -0,0 +1,232 @@ +"""Integration tests for AioSQLite session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest + +from sqlspec.adapters.aiosqlite.config import AiosqliteConfig +from sqlspec.adapters.aiosqlite.litestar.store import AiosqliteStore + +pytestmark = [pytest.mark.aiosqlite, pytest.mark.integration] + + +@pytest.fixture +async def aiosqlite_store() -> "AsyncGenerator[AiosqliteStore, None]": + """Create AioSQLite store with in-memory database.""" + config = AiosqliteConfig(pool_config={"database": ":memory:"}) + store = AiosqliteStore(config, table_name="test_sessions") + await store.create_table() + yield store + await store.delete_all() + + +async def test_store_create_table(aiosqlite_store: AiosqliteStore) -> None: + """Test table creation.""" + assert aiosqlite_store.table_name == "test_sessions" + + +async def test_store_set_and_get(aiosqlite_store: AiosqliteStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await aiosqlite_store.set("session_123", test_data) + + result = await aiosqlite_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(aiosqlite_store: AiosqliteStore) -> None: + """Test getting a non-existent session returns None.""" + result = await aiosqlite_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(aiosqlite_store: AiosqliteStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await aiosqlite_store.set("session_str", "string data") + + result = await aiosqlite_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(aiosqlite_store: AiosqliteStore) -> None: + """Test delete operation.""" + await aiosqlite_store.set("session_to_delete", b"data") + + assert await aiosqlite_store.exists("session_to_delete") 
+ + await aiosqlite_store.delete("session_to_delete") + + assert not await aiosqlite_store.exists("session_to_delete") + assert await aiosqlite_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(aiosqlite_store: AiosqliteStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await aiosqlite_store.delete("nonexistent") + + +async def test_store_expiration_with_int(aiosqlite_store: AiosqliteStore) -> None: + """Test session expiration with integer seconds.""" + await aiosqlite_store.set("expiring_session", b"data", expires_in=1) + + assert await aiosqlite_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await aiosqlite_store.get("expiring_session") + assert result is None + assert not await aiosqlite_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(aiosqlite_store: AiosqliteStore) -> None: + """Test session expiration with timedelta.""" + await aiosqlite_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await aiosqlite_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await aiosqlite_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(aiosqlite_store: AiosqliteStore) -> None: + """Test session without expiration persists.""" + await aiosqlite_store.set("permanent_session", b"data") + + expires_in = await aiosqlite_store.expires_in("permanent_session") + assert expires_in is None + + assert await aiosqlite_store.exists("permanent_session") + + +async def test_store_expires_in(aiosqlite_store: AiosqliteStore) -> None: + """Test expires_in returns correct time.""" + await aiosqlite_store.set("timed_session", b"data", expires_in=10) + + expires_in = await aiosqlite_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(aiosqlite_store: AiosqliteStore) -> None: + """Test expires_in returns 0 for expired session.""" + await aiosqlite_store.set("expired_session", b"data", expires_in=1) + + await asyncio.sleep(1.1) + + expires_in = await aiosqlite_store.expires_in("expired_session") + assert expires_in == 0 + + +async def test_store_cleanup(aiosqlite_store: AiosqliteStore) -> None: + """Test delete_expired removes only expired sessions.""" + await aiosqlite_store.set("active_session", b"data", expires_in=60) + await aiosqlite_store.set("expired_session_1", b"data", expires_in=1) + await aiosqlite_store.set("expired_session_2", b"data", expires_in=1) + await aiosqlite_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) + + count = await aiosqlite_store.delete_expired() + assert count == 2 + + assert await aiosqlite_store.exists("active_session") + assert await aiosqlite_store.exists("permanent_session") + assert not await aiosqlite_store.exists("expired_session_1") + assert not await aiosqlite_store.exists("expired_session_2") + + +async def test_store_upsert(aiosqlite_store: AiosqliteStore) -> None: + """Test updating existing session (UPSERT).""" + await aiosqlite_store.set("session_upsert", b"original data") + + result = await aiosqlite_store.get("session_upsert") + assert result == b"original data" + + await aiosqlite_store.set("session_upsert", b"updated data") + + result = await aiosqlite_store.get("session_upsert") + assert result == b"updated data" + + +async def test_store_upsert_with_expiration_change(aiosqlite_store: AiosqliteStore) -> None: + """Test updating session 
expiration.""" + await aiosqlite_store.set("session_exp", b"data", expires_in=60) + + expires_in = await aiosqlite_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await aiosqlite_store.set("session_exp", b"data", expires_in=10) + + expires_in = await aiosqlite_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(aiosqlite_store: AiosqliteStore) -> None: + """Test renewing session expiration on get.""" + await aiosqlite_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await aiosqlite_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await aiosqlite_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await aiosqlite_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(aiosqlite_store: AiosqliteStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await aiosqlite_store.set("large_session", large_data) + + result = await aiosqlite_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(aiosqlite_store: AiosqliteStore) -> None: + """Test delete_all removes all sessions.""" + await aiosqlite_store.set("session1", b"data1") + await aiosqlite_store.set("session2", b"data2") + await aiosqlite_store.set("session3", b"data3") + + assert await aiosqlite_store.exists("session1") + assert await aiosqlite_store.exists("session2") + assert await aiosqlite_store.exists("session3") + + await aiosqlite_store.delete_all() + + assert not await aiosqlite_store.exists("session1") + assert not await aiosqlite_store.exists("session2") + assert not await aiosqlite_store.exists("session3") + + +async def test_store_exists(aiosqlite_store: AiosqliteStore) -> None: + """Test exists method.""" + assert not await aiosqlite_store.exists("test_session") + + await aiosqlite_store.set("test_session", b"data") + + assert await aiosqlite_store.exists("test_session") + + +async def test_store_context_manager(aiosqlite_store: AiosqliteStore) -> None: + """Test store can be used as async context manager.""" + async with aiosqlite_store: + await aiosqlite_store.set("ctx_session", b"data") + + result = await aiosqlite_store.get("ctx_session") + assert result == b"data" diff --git a/tests/integration/test_adapters/test_asyncmy/test_asyncmy_features.py b/tests/integration/test_adapters/test_asyncmy/test_asyncmy_features.py index 8d257ea1..08625716 100644 --- a/tests/integration/test_adapters/test_asyncmy/test_asyncmy_features.py +++ b/tests/integration/test_adapters/test_asyncmy/test_asyncmy_features.py @@ -52,7 +52,6 @@ async def asyncmy_pooled_session(mysql_service: MySQLService) -> AsyncGenerator[ yield session -@pytest.mark.asyncio async def test_asyncmy_mysql_json_operations(asyncmy_pooled_session: AsyncmyDriver) -> None: """Test MySQL JSON column operations.""" driver = asyncmy_pooled_session @@ -87,7 +86,6 @@ async def test_asyncmy_mysql_json_operations(asyncmy_pooled_session: AsyncmyDriv assert contains_result.get_data()[0]["count"] == 1 -@pytest.mark.asyncio async def test_asyncmy_mysql_specific_sql_features(asyncmy_pooled_session: AsyncmyDriver) -> None: """Test MySQL-specific SQL features and syntax.""" driver = 
asyncmy_pooled_session @@ -131,7 +129,6 @@ async def test_asyncmy_mysql_specific_sql_features(asyncmy_pooled_session: Async assert "important" in enum_row["tags"] -@pytest.mark.asyncio async def test_asyncmy_transaction_isolation_levels(asyncmy_pooled_session: AsyncmyDriver) -> None: """Test MySQL transaction isolation level handling.""" driver = asyncmy_pooled_session @@ -158,7 +155,6 @@ async def test_asyncmy_transaction_isolation_levels(asyncmy_pooled_session: Asyn assert committed_result.get_data()[0]["value"] == "transaction_data" -@pytest.mark.asyncio async def test_asyncmy_stored_procedures(asyncmy_pooled_session: AsyncmyDriver) -> None: """Test stored procedure execution.""" driver = asyncmy_pooled_session @@ -185,7 +181,6 @@ async def test_asyncmy_stored_procedures(asyncmy_pooled_session: AsyncmyDriver) await driver.execute("CALL simple_procedure(?)", (5,)) -@pytest.mark.asyncio async def test_asyncmy_bulk_operations_performance(asyncmy_pooled_session: AsyncmyDriver) -> None: """Test bulk operations for performance characteristics.""" driver = asyncmy_pooled_session @@ -220,7 +215,6 @@ async def test_asyncmy_bulk_operations_performance(asyncmy_pooled_session: Async assert select_result.get_data()[99]["sequence_num"] == 99 -@pytest.mark.asyncio async def test_asyncmy_error_recovery(asyncmy_pooled_session: AsyncmyDriver) -> None: """Test error handling and connection recovery.""" driver = asyncmy_pooled_session @@ -247,7 +241,6 @@ async def test_asyncmy_error_recovery(asyncmy_pooled_session: AsyncmyDriver) -> assert final_result.get_data()[0]["value"] == "test_value" -@pytest.mark.asyncio async def test_asyncmy_sql_object_advanced_features(asyncmy_pooled_session: AsyncmyDriver) -> None: """Test SQL object integration with advanced AsyncMy features.""" driver = asyncmy_pooled_session diff --git a/tests/integration/test_adapters/test_asyncmy/test_config.py b/tests/integration/test_adapters/test_asyncmy/test_config.py index 43cec5d1..3ef66d41 100644 --- a/tests/integration/test_adapters/test_asyncmy/test_config.py +++ b/tests/integration/test_adapters/test_asyncmy/test_config.py @@ -79,7 +79,6 @@ def test_asyncmy_config_initialization() -> None: assert config.statement_config is custom_statement_config -@pytest.mark.asyncio async def test_asyncmy_config_provide_session(mysql_service: MySQLService) -> None: """Test Asyncmy config provide_session context manager.""" diff --git a/tests/integration/test_adapters/test_asyncmy/test_driver.py b/tests/integration/test_adapters/test_asyncmy/test_driver.py index c72c106d..5e0eb9e9 100644 --- a/tests/integration/test_adapters/test_asyncmy/test_driver.py +++ b/tests/integration/test_adapters/test_asyncmy/test_driver.py @@ -61,7 +61,6 @@ async def asyncmy_session(mysql_service: MySQLService) -> AsyncGenerator[Asyncmy yield session -@pytest.mark.asyncio async def test_asyncmy_basic_crud(asyncmy_driver: AsyncmyDriver) -> None: """Test basic CRUD operations.""" driver = asyncmy_driver @@ -89,7 +88,6 @@ async def test_asyncmy_basic_crud(asyncmy_driver: AsyncmyDriver) -> None: assert verify_result.get_data()[0]["count"] == 0 -@pytest.mark.asyncio async def test_asyncmy_parameter_styles(asyncmy_driver: AsyncmyDriver) -> None: """Test different parameter binding styles.""" driver = asyncmy_driver @@ -108,7 +106,6 @@ async def test_asyncmy_parameter_styles(asyncmy_driver: AsyncmyDriver) -> None: assert select_result.get_data()[1]["value"] == 20 -@pytest.mark.asyncio async def test_asyncmy_execute_many(asyncmy_driver: AsyncmyDriver) -> None: """Test 
execute_many functionality.""" driver = asyncmy_driver @@ -126,7 +123,6 @@ async def test_asyncmy_execute_many(asyncmy_driver: AsyncmyDriver) -> None: assert select_result.get_data()[0]["value"] == 100 -@pytest.mark.asyncio async def test_asyncmy_execute_script(asyncmy_driver: AsyncmyDriver) -> None: """Test script execution with multiple statements.""" driver = asyncmy_driver @@ -148,7 +144,6 @@ async def test_asyncmy_execute_script(asyncmy_driver: AsyncmyDriver) -> None: assert select_result.get_data()[1]["value"] == 4000 -@pytest.mark.asyncio async def test_asyncmy_data_types(asyncmy_driver: AsyncmyDriver) -> None: """Test handling of various MySQL data types.""" driver = asyncmy_driver @@ -189,7 +184,6 @@ async def test_asyncmy_data_types(asyncmy_driver: AsyncmyDriver) -> None: assert row["bool_col"] == 1 -@pytest.mark.asyncio async def test_asyncmy_transaction_management(asyncmy_driver: AsyncmyDriver) -> None: """Test transaction management (begin, commit, rollback).""" driver = asyncmy_driver @@ -209,7 +203,6 @@ async def test_asyncmy_transaction_management(asyncmy_driver: AsyncmyDriver) -> assert result.get_data()[0]["count"] == 0 -@pytest.mark.asyncio async def test_asyncmy_null_parameters(asyncmy_driver: AsyncmyDriver) -> None: """Test handling of NULL parameters.""" driver = asyncmy_driver @@ -223,7 +216,6 @@ async def test_asyncmy_null_parameters(asyncmy_driver: AsyncmyDriver) -> None: assert select_result.get_data()[0]["value"] is None -@pytest.mark.asyncio async def test_asyncmy_error_handling(asyncmy_driver: AsyncmyDriver) -> None: """Test error handling and exception wrapping.""" driver = asyncmy_driver @@ -237,7 +229,6 @@ async def test_asyncmy_error_handling(asyncmy_driver: AsyncmyDriver) -> None: await driver.execute("INSERT INTO test_table (id, name, value) VALUES (?, ?, ?)", (1, "user2", 200)) -@pytest.mark.asyncio async def test_asyncmy_large_result_set(asyncmy_driver: AsyncmyDriver) -> None: """Test handling of large result sets.""" driver = asyncmy_driver @@ -252,7 +243,6 @@ async def test_asyncmy_large_result_set(asyncmy_driver: AsyncmyDriver) -> None: assert result.get_data()[99]["name"] == "user_99" -@pytest.mark.asyncio async def test_asyncmy_mysql_specific_features(asyncmy_driver: AsyncmyDriver) -> None: """Test MySQL-specific features and SQL constructs.""" driver = asyncmy_driver @@ -269,7 +259,6 @@ async def test_asyncmy_mysql_specific_features(asyncmy_driver: AsyncmyDriver) -> assert select_result.get_data()[0]["value"] == 250 -@pytest.mark.asyncio async def test_asyncmy_complex_queries(asyncmy_driver: AsyncmyDriver) -> None: """Test complex SQL queries with JOINs, subqueries, etc.""" driver = asyncmy_driver @@ -304,7 +293,6 @@ async def test_asyncmy_complex_queries(asyncmy_driver: AsyncmyDriver) -> None: assert row["age"] == 30 -@pytest.mark.asyncio async def test_asyncmy_edge_cases(asyncmy_driver: AsyncmyDriver) -> None: """Test edge cases and boundary conditions.""" driver = asyncmy_driver @@ -328,7 +316,6 @@ async def test_asyncmy_edge_cases(asyncmy_driver: AsyncmyDriver) -> None: assert select_result.get_data()[1]["value"] is None -@pytest.mark.asyncio async def test_asyncmy_result_metadata(asyncmy_driver: AsyncmyDriver) -> None: """Test SQL result metadata and properties.""" driver = asyncmy_driver @@ -350,7 +337,6 @@ async def test_asyncmy_result_metadata(asyncmy_driver: AsyncmyDriver) -> None: assert len(empty_result.get_data()) == 0 -@pytest.mark.asyncio async def test_asyncmy_sql_object_execution(asyncmy_driver: AsyncmyDriver) -> None: """Test 
execution of SQL objects.""" driver = asyncmy_driver @@ -372,7 +358,6 @@ async def test_asyncmy_sql_object_execution(asyncmy_driver: AsyncmyDriver) -> No assert select_result.operation_type == "SELECT" -@pytest.mark.asyncio async def test_asyncmy_for_update_locking(asyncmy_driver: AsyncmyDriver) -> None: """Test FOR UPDATE row locking with MySQL.""" from sqlspec import sql @@ -399,7 +384,6 @@ async def test_asyncmy_for_update_locking(asyncmy_driver: AsyncmyDriver) -> None raise -@pytest.mark.asyncio async def test_asyncmy_for_update_skip_locked(asyncmy_driver: AsyncmyDriver) -> None: """Test FOR UPDATE SKIP LOCKED with MySQL (MySQL 8.0+ feature).""" from sqlspec import sql @@ -425,7 +409,6 @@ async def test_asyncmy_for_update_skip_locked(asyncmy_driver: AsyncmyDriver) -> raise -@pytest.mark.asyncio async def test_asyncmy_for_share_locking(asyncmy_driver: AsyncmyDriver) -> None: """Test FOR SHARE row locking with MySQL.""" from sqlspec import sql diff --git a/tests/integration/test_adapters/test_asyncmy/test_extensions/__init__.py b/tests/integration/test_adapters/test_asyncmy/test_extensions/__init__.py new file mode 100644 index 00000000..508f8451 --- /dev/null +++ b/tests/integration/test_adapters/test_asyncmy/test_extensions/__init__.py @@ -0,0 +1 @@ +"""AsyncMy extensions integration tests.""" diff --git a/tests/integration/test_adapters/test_asyncmy/test_extensions/test_litestar/__init__.py b/tests/integration/test_adapters/test_asyncmy/test_extensions/test_litestar/__init__.py new file mode 100644 index 00000000..6907eaec --- /dev/null +++ b/tests/integration/test_adapters/test_asyncmy/test_extensions/test_litestar/__init__.py @@ -0,0 +1 @@ +"""AsyncMy Litestar integration tests.""" diff --git a/tests/integration/test_adapters/test_asyncmy/test_extensions/test_litestar/test_store.py b/tests/integration/test_adapters/test_asyncmy/test_extensions/test_litestar/test_store.py new file mode 100644 index 00000000..371eca28 --- /dev/null +++ b/tests/integration/test_adapters/test_asyncmy/test_extensions/test_litestar/test_store.py @@ -0,0 +1,251 @@ +"""Integration tests for AsyncMy session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest +from pytest_databases.docker.mysql import MySQLService + +from sqlspec.adapters.asyncmy.config import AsyncmyConfig +from sqlspec.adapters.asyncmy.litestar.store import AsyncmyStore + +pytestmark = [pytest.mark.xdist_group("mysql"), pytest.mark.asyncmy, pytest.mark.integration] + + +@pytest.fixture +async def asyncmy_store(mysql_service: MySQLService) -> "AsyncGenerator[AsyncmyStore, None]": + """Create AsyncMy store with test database.""" + config = AsyncmyConfig( + pool_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "database": mysql_service.db, + } + ) + store = AsyncmyStore(config, table_name="test_asyncmy_sessions") + try: + await store.create_table() + yield store + try: + await store.delete_all() + except Exception: + pass + finally: + try: + if config.pool_instance: + await config.close_pool() + except Exception: + pass + + +async def test_store_create_table(asyncmy_store: AsyncmyStore) -> None: + """Test table creation.""" + assert asyncmy_store.table_name == "test_asyncmy_sessions" + + +async def test_store_set_and_get(asyncmy_store: AsyncmyStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await asyncmy_store.set("session_123", test_data) + 
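+ # Read back via a separate query to verify the payload round-trips as bytes.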
+ result = await asyncmy_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(asyncmy_store: AsyncmyStore) -> None: + """Test getting a non-existent session returns None.""" + result = await asyncmy_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(asyncmy_store: AsyncmyStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await asyncmy_store.set("session_str", "string data") + + result = await asyncmy_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(asyncmy_store: AsyncmyStore) -> None: + """Test delete operation.""" + await asyncmy_store.set("session_to_delete", b"data") + + assert await asyncmy_store.exists("session_to_delete") + + await asyncmy_store.delete("session_to_delete") + + assert not await asyncmy_store.exists("session_to_delete") + assert await asyncmy_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(asyncmy_store: AsyncmyStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await asyncmy_store.delete("nonexistent") + + +async def test_store_expiration_with_int(asyncmy_store: AsyncmyStore) -> None: + """Test session expiration with integer seconds.""" + await asyncmy_store.set("expiring_session", b"data", expires_in=1) + + assert await asyncmy_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await asyncmy_store.get("expiring_session") + assert result is None + assert not await asyncmy_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(asyncmy_store: AsyncmyStore) -> None: + """Test session expiration with timedelta.""" + await asyncmy_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await asyncmy_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await asyncmy_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(asyncmy_store: AsyncmyStore) -> None: + """Test session without expiration persists.""" + await asyncmy_store.set("permanent_session", b"data") + + expires_in = await asyncmy_store.expires_in("permanent_session") + assert expires_in is None + + assert await asyncmy_store.exists("permanent_session") + + +async def test_store_expires_in(asyncmy_store: AsyncmyStore) -> None: + """Test expires_in returns correct time.""" + await asyncmy_store.set("timed_session", b"data", expires_in=10) + + expires_in = await asyncmy_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(asyncmy_store: AsyncmyStore) -> None: + """Test expires_in returns 0 for expired session.""" + await asyncmy_store.set("expired_session", b"data", expires_in=1) + + await asyncio.sleep(1.1) + + expires_in = await asyncmy_store.expires_in("expired_session") + assert expires_in == 0 + + +async def test_store_cleanup(asyncmy_store: AsyncmyStore) -> None: + """Test delete_expired removes only expired sessions.""" + await asyncmy_store.set("active_session", b"data", expires_in=60) + await asyncmy_store.set("expired_session_1", b"data", expires_in=1) + await asyncmy_store.set("expired_session_2", b"data", expires_in=1) + await asyncmy_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) + + count = await asyncmy_store.delete_expired() + assert count == 2 + + assert await asyncmy_store.exists("active_session") + assert await 
asyncmy_store.exists("permanent_session") + assert not await asyncmy_store.exists("expired_session_1") + assert not await asyncmy_store.exists("expired_session_2") + + +async def test_store_upsert(asyncmy_store: AsyncmyStore) -> None: + """Test updating existing session (UPSERT).""" + await asyncmy_store.set("session_upsert", b"original data") + + result = await asyncmy_store.get("session_upsert") + assert result == b"original data" + + await asyncmy_store.set("session_upsert", b"updated data") + + result = await asyncmy_store.get("session_upsert") + assert result == b"updated data" + + +async def test_store_upsert_with_expiration_change(asyncmy_store: AsyncmyStore) -> None: + """Test updating session expiration.""" + await asyncmy_store.set("session_exp", b"data", expires_in=60) + + expires_in = await asyncmy_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await asyncmy_store.set("session_exp", b"data", expires_in=10) + + expires_in = await asyncmy_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(asyncmy_store: AsyncmyStore) -> None: + """Test renewing session expiration on get.""" + await asyncmy_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await asyncmy_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await asyncmy_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await asyncmy_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(asyncmy_store: AsyncmyStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await asyncmy_store.set("large_session", large_data) + + result = await asyncmy_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(asyncmy_store: AsyncmyStore) -> None: + """Test delete_all removes all sessions.""" + await asyncmy_store.set("session1", b"data1") + await asyncmy_store.set("session2", b"data2") + await asyncmy_store.set("session3", b"data3") + + assert await asyncmy_store.exists("session1") + assert await asyncmy_store.exists("session2") + assert await asyncmy_store.exists("session3") + + await asyncmy_store.delete_all() + + assert not await asyncmy_store.exists("session1") + assert not await asyncmy_store.exists("session2") + assert not await asyncmy_store.exists("session3") + + +async def test_store_exists(asyncmy_store: AsyncmyStore) -> None: + """Test exists method.""" + assert not await asyncmy_store.exists("test_session") + + await asyncmy_store.set("test_session", b"data") + + assert await asyncmy_store.exists("test_session") + + +async def test_store_context_manager(asyncmy_store: AsyncmyStore) -> None: + """Test store can be used as async context manager.""" + async with asyncmy_store: + await asyncmy_store.set("ctx_session", b"data") + + result = await asyncmy_store.get("ctx_session") + assert result == b"data" diff --git a/tests/integration/test_adapters/test_asyncmy/test_parameter_styles.py b/tests/integration/test_adapters/test_asyncmy/test_parameter_styles.py index fb6b4d3f..658f0e15 100644 --- a/tests/integration/test_adapters/test_asyncmy/test_parameter_styles.py +++ b/tests/integration/test_adapters/test_asyncmy/test_parameter_styles.py @@ -67,7 
+67,6 @@ async def asyncmy_parameter_session(mysql_service: MySQLService) -> AsyncGenerat await session.execute_script("DROP TABLE IF EXISTS test_parameter_conversion") -@pytest.mark.asyncio async def test_asyncmy_qmark_to_pyformat_conversion(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test that ? placeholders get converted to %s placeholders.""" driver = asyncmy_parameter_session @@ -82,7 +81,6 @@ async def test_asyncmy_qmark_to_pyformat_conversion(asyncmy_parameter_session: A assert result.data[0]["value"] == 100 -@pytest.mark.asyncio async def test_asyncmy_pyformat_no_conversion_needed(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test that %s placeholders are used directly without conversion (native format).""" driver = asyncmy_parameter_session @@ -99,7 +97,6 @@ async def test_asyncmy_pyformat_no_conversion_needed(asyncmy_parameter_session: assert result.data[0]["value"] == 200 -@pytest.mark.asyncio async def test_asyncmy_named_to_pyformat_conversion(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test that %(name)s placeholders get converted to %s placeholders.""" driver = asyncmy_parameter_session @@ -117,7 +114,6 @@ async def test_asyncmy_named_to_pyformat_conversion(asyncmy_parameter_session: A assert result.data[0]["value"] == 300 -@pytest.mark.asyncio async def test_asyncmy_sql_object_conversion_validation(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test parameter conversion with SQL object containing different parameter styles.""" driver = asyncmy_parameter_session @@ -141,7 +137,6 @@ async def test_asyncmy_sql_object_conversion_validation(asyncmy_parameter_sessio assert "test3" in names -@pytest.mark.asyncio async def test_asyncmy_mixed_parameter_types_conversion(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test conversion with different parameter value types.""" driver = asyncmy_parameter_session @@ -162,7 +157,6 @@ async def test_asyncmy_mixed_parameter_types_conversion(asyncmy_parameter_sessio assert result.data[0]["description"] == "Mixed type test" -@pytest.mark.asyncio async def test_asyncmy_execute_many_parameter_conversion(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test parameter conversion in execute_many operations.""" driver = asyncmy_parameter_session @@ -184,7 +178,6 @@ async def test_asyncmy_execute_many_parameter_conversion(asyncmy_parameter_sessi assert verify_result.data[0]["count"] == 3 -@pytest.mark.asyncio async def test_asyncmy_parameter_conversion_edge_cases(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test edge cases in parameter conversion.""" driver = asyncmy_parameter_session @@ -205,7 +198,6 @@ async def test_asyncmy_parameter_conversion_edge_cases(asyncmy_parameter_session assert result3.data[0]["count"] >= 3 -@pytest.mark.asyncio async def test_asyncmy_parameter_style_consistency_validation(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test that the parameter conversion maintains consistency.""" driver = asyncmy_parameter_session @@ -228,7 +220,6 @@ async def test_asyncmy_parameter_style_consistency_validation(asyncmy_parameter_ assert result_qmark.data[i]["value"] == result_pyformat.data[i]["value"] -@pytest.mark.asyncio async def test_asyncmy_complex_query_parameter_conversion(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test parameter conversion in complex queries with multiple operations.""" driver = asyncmy_parameter_session @@ -260,7 +251,6 @@ async def test_asyncmy_complex_query_parameter_conversion(asyncmy_parameter_sess assert result.data[0]["value"] == 250 
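+ # NOTE: explicit @pytest.mark.asyncio markers are dropped throughout this diff; async tests are presumably auto-collected, e.g. via pytest-asyncio's asyncio_mode = "auto" (an assumption; the pytest configuration is not shown in this diff).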
-@pytest.mark.asyncio async def test_asyncmy_mysql_parameter_style_specifics(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test MySQL-specific parameter handling requirements.""" driver = asyncmy_parameter_session @@ -291,7 +281,6 @@ async def test_asyncmy_mysql_parameter_style_specifics(asyncmy_parameter_session assert verify_result.data[0]["value"] == 888 -@pytest.mark.asyncio async def test_asyncmy_2phase_parameter_processing(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test the 2-phase parameter processing system specific to AsyncMy/MySQL.""" driver = asyncmy_parameter_session @@ -324,7 +313,6 @@ async def test_asyncmy_2phase_parameter_processing(asyncmy_parameter_session: As assert all(count == consistent_results[0] for count in consistent_results) -@pytest.mark.asyncio async def test_asyncmy_none_parameters_pyformat(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test None values with PYFORMAT (%s) parameter style.""" driver = asyncmy_parameter_session @@ -363,7 +351,6 @@ async def test_asyncmy_none_parameters_pyformat(asyncmy_parameter_session: Async assert row["created_at"] is None -@pytest.mark.asyncio async def test_asyncmy_none_parameters_qmark(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test None values with QMARK (?) parameter style.""" driver = asyncmy_parameter_session @@ -396,7 +383,6 @@ async def test_asyncmy_none_parameters_qmark(asyncmy_parameter_session: AsyncmyD assert row["optional_field"] is None -@pytest.mark.asyncio async def test_asyncmy_none_parameters_named_pyformat(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test None values with named PYFORMAT %(name)s parameter style.""" driver = asyncmy_parameter_session @@ -440,7 +426,6 @@ async def test_asyncmy_none_parameters_named_pyformat(asyncmy_parameter_session: assert row["metadata"] is None -@pytest.mark.asyncio async def test_asyncmy_all_none_parameters(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test when all parameter values are None.""" driver = asyncmy_parameter_session @@ -478,7 +463,6 @@ async def test_asyncmy_all_none_parameters(asyncmy_parameter_session: AsyncmyDri assert row["col4"] is None -@pytest.mark.asyncio async def test_asyncmy_none_with_execute_many(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test None values work correctly with execute_many.""" driver = asyncmy_parameter_session @@ -523,7 +507,6 @@ async def test_asyncmy_none_with_execute_many(asyncmy_parameter_session: Asyncmy assert rows[4]["name"] == "item5" and rows[4]["value"] is None and rows[4]["category"] is None -@pytest.mark.asyncio async def test_asyncmy_none_parameter_count_validation(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test that parameter count mismatches are properly detected with None values. 
@@ -570,7 +553,6 @@ async def test_asyncmy_none_parameter_count_validation(asyncmy_parameter_session assert any(keyword in error_msg for keyword in ["parameter", "argument", "mismatch", "count"]) -@pytest.mark.asyncio async def test_asyncmy_none_in_where_clauses(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test None values in WHERE clauses work correctly.""" driver = asyncmy_parameter_session @@ -614,7 +596,6 @@ async def test_asyncmy_none_in_where_clauses(asyncmy_parameter_session: AsyncmyD assert len(result2.data) == 4 # All rows because second condition is always true -@pytest.mark.asyncio async def test_asyncmy_none_complex_scenarios(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test complex scenarios with None parameters.""" driver = asyncmy_parameter_session @@ -672,7 +653,6 @@ async def test_asyncmy_none_complex_scenarios(asyncmy_parameter_session: Asyncmy assert row["metadata"] is None -@pytest.mark.asyncio async def test_asyncmy_none_edge_cases(asyncmy_parameter_session: AsyncmyDriver) -> None: """Test edge cases that might reveal None parameter handling bugs.""" driver = asyncmy_parameter_session diff --git a/tests/integration/test_adapters/test_asyncpg/test_execute_many.py b/tests/integration/test_adapters/test_asyncpg/test_execute_many.py index 880dcadb..802b2ff3 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_execute_many.py +++ b/tests/integration/test_adapters/test_asyncpg/test_execute_many.py @@ -45,7 +45,6 @@ async def asyncpg_batch_session(postgres_service: PostgresService) -> "AsyncGene await config.close_pool() -@pytest.mark.asyncio async def test_asyncpg_execute_many_basic(asyncpg_batch_session: AsyncpgDriver) -> None: """Test basic execute_many with AsyncPG.""" parameters = [ @@ -68,7 +67,6 @@ async def test_asyncpg_execute_many_basic(asyncpg_batch_session: AsyncpgDriver) assert count_result[0]["count"] == 5 -@pytest.mark.asyncio async def test_asyncpg_execute_many_update(asyncpg_batch_session: AsyncpgDriver) -> None: """Test execute_many for UPDATE operations with AsyncPG.""" @@ -90,7 +88,6 @@ async def test_asyncpg_execute_many_update(asyncpg_batch_session: AsyncpgDriver) assert all(row["value"] in (100, 200, 300) for row in check_result) -@pytest.mark.asyncio async def test_asyncpg_execute_many_empty(asyncpg_batch_session: AsyncpgDriver) -> None: """Test execute_many with empty parameter list on AsyncPG.""" result = await asyncpg_batch_session.execute_many( @@ -104,7 +101,6 @@ async def test_asyncpg_execute_many_empty(asyncpg_batch_session: AsyncpgDriver) assert count_result[0]["count"] == 0 -@pytest.mark.asyncio async def test_asyncpg_execute_many_mixed_types(asyncpg_batch_session: AsyncpgDriver) -> None: """Test execute_many with mixed parameter types on AsyncPG.""" parameters = [ @@ -129,7 +125,6 @@ async def test_asyncpg_execute_many_mixed_types(asyncpg_batch_session: AsyncpgDr assert negative_result[0]["value"] == -50 -@pytest.mark.asyncio async def test_asyncpg_execute_many_delete(asyncpg_batch_session: AsyncpgDriver) -> None: """Test execute_many for DELETE operations with AsyncPG.""" @@ -158,7 +153,6 @@ async def test_asyncpg_execute_many_delete(asyncpg_batch_session: AsyncpgDriver) assert remaining_names == ["Delete 3", "Keep 1"] -@pytest.mark.asyncio async def test_asyncpg_execute_many_large_batch(asyncpg_batch_session: AsyncpgDriver) -> None: """Test execute_many with large batch size on AsyncPG.""" @@ -182,7 +176,6 @@ async def test_asyncpg_execute_many_large_batch(asyncpg_batch_session: AsyncpgDr assert 
sample_result[2]["value"] == 9990 -@pytest.mark.asyncio async def test_asyncpg_execute_many_with_sql_object(asyncpg_batch_session: AsyncpgDriver) -> None: """Test execute_many with SQL object on AsyncPG.""" from sqlspec.core.statement import SQL @@ -201,7 +194,6 @@ async def test_asyncpg_execute_many_with_sql_object(asyncpg_batch_session: Async assert check_result[0]["count"] == 3 -@pytest.mark.asyncio async def test_asyncpg_execute_many_with_returning(asyncpg_batch_session: AsyncpgDriver) -> None: """Test execute_many with RETURNING clause on AsyncPG.""" parameters = [("Return 1", 111, "RET"), ("Return 2", 222, "RET"), ("Return 3", 333, "RET")] @@ -227,7 +219,6 @@ async def test_asyncpg_execute_many_with_returning(asyncpg_batch_session: Asyncp assert check_result[0]["count"] == 3 -@pytest.mark.asyncio async def test_asyncpg_execute_many_with_arrays(asyncpg_batch_session: AsyncpgDriver) -> None: """Test execute_many with PostgreSQL array types on AsyncPG.""" @@ -262,7 +253,6 @@ async def test_asyncpg_execute_many_with_arrays(asyncpg_batch_session: AsyncpgDr assert check_result[2]["tag_count"] == 3 -@pytest.mark.asyncio async def test_asyncpg_execute_many_with_json(asyncpg_batch_session: AsyncpgDriver) -> None: """Test execute_many with JSON data on AsyncPG.""" await asyncpg_batch_session.execute_script(""" diff --git a/tests/integration/test_adapters/test_asyncpg/test_extensions/__init__.py b/tests/integration/test_adapters/test_asyncpg/test_extensions/__init__.py new file mode 100644 index 00000000..37955c08 --- /dev/null +++ b/tests/integration/test_adapters/test_asyncpg/test_extensions/__init__.py @@ -0,0 +1 @@ +"""AsyncPG extension integration tests.""" diff --git a/tests/integration/test_adapters/test_asyncpg/test_extensions/test_litestar/__init__.py b/tests/integration/test_adapters/test_asyncpg/test_extensions/test_litestar/__init__.py new file mode 100644 index 00000000..3ae3474c --- /dev/null +++ b/tests/integration/test_adapters/test_asyncpg/test_extensions/test_litestar/__init__.py @@ -0,0 +1 @@ +"""Integration tests for AsyncPG Litestar extensions.""" diff --git a/tests/integration/test_adapters/test_asyncpg/test_extensions/test_litestar/test_store.py b/tests/integration/test_adapters/test_asyncpg/test_extensions/test_litestar/test_store.py new file mode 100644 index 00000000..91517eea --- /dev/null +++ b/tests/integration/test_adapters/test_asyncpg/test_extensions/test_litestar/test_store.py @@ -0,0 +1,245 @@ +"""Integration tests for AsyncPG session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest +from pytest_databases.docker.postgres import PostgresService + +from sqlspec.adapters.asyncpg.config import AsyncpgConfig +from sqlspec.adapters.asyncpg.litestar.store import AsyncpgStore + +pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.asyncpg, pytest.mark.integration] + + +@pytest.fixture +async def asyncpg_store(postgres_service: PostgresService) -> "AsyncGenerator[AsyncpgStore, None]": + """Create AsyncPG store with test database.""" + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + } + ) + store = AsyncpgStore(config, table_name="test_sessions") + try: + await store.create_table() + yield store + await store.delete_all() + finally: + if config.pool_instance: + await config.close_pool() + + +async def 
test_store_create_table(asyncpg_store: AsyncpgStore) -> None: + """Test table creation.""" + assert asyncpg_store.table_name == "test_sessions" + + +async def test_store_set_and_get(asyncpg_store: AsyncpgStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await asyncpg_store.set("session_123", test_data) + + result = await asyncpg_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(asyncpg_store: AsyncpgStore) -> None: + """Test getting a non-existent session returns None.""" + result = await asyncpg_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(asyncpg_store: AsyncpgStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await asyncpg_store.set("session_str", "string data") + + result = await asyncpg_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(asyncpg_store: AsyncpgStore) -> None: + """Test delete operation.""" + await asyncpg_store.set("session_to_delete", b"data") + + assert await asyncpg_store.exists("session_to_delete") + + await asyncpg_store.delete("session_to_delete") + + assert not await asyncpg_store.exists("session_to_delete") + assert await asyncpg_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(asyncpg_store: AsyncpgStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await asyncpg_store.delete("nonexistent") + + +async def test_store_expiration_with_int(asyncpg_store: AsyncpgStore) -> None: + """Test session expiration with integer seconds.""" + await asyncpg_store.set("expiring_session", b"data", expires_in=1) + + assert await asyncpg_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await asyncpg_store.get("expiring_session") + assert result is None + assert not await asyncpg_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(asyncpg_store: AsyncpgStore) -> None: + """Test session expiration with timedelta.""" + await asyncpg_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await asyncpg_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await asyncpg_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(asyncpg_store: AsyncpgStore) -> None: + """Test session without expiration persists.""" + await asyncpg_store.set("permanent_session", b"data") + + expires_in = await asyncpg_store.expires_in("permanent_session") + assert expires_in is None + + assert await asyncpg_store.exists("permanent_session") + + +async def test_store_expires_in(asyncpg_store: AsyncpgStore) -> None: + """Test expires_in returns correct time.""" + await asyncpg_store.set("timed_session", b"data", expires_in=10) + + expires_in = await asyncpg_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(asyncpg_store: AsyncpgStore) -> None: + """Test expires_in returns 0 for expired session.""" + await asyncpg_store.set("expired_session", b"data", expires_in=1) + + await asyncio.sleep(1.1) + + expires_in = await asyncpg_store.expires_in("expired_session") + assert expires_in == 0 + + +async def test_store_cleanup(asyncpg_store: AsyncpgStore) -> None: + """Test delete_expired removes only expired sessions.""" + await asyncpg_store.set("active_session", b"data", expires_in=60) + await 
asyncpg_store.set("expired_session_1", b"data", expires_in=1) + await asyncpg_store.set("expired_session_2", b"data", expires_in=1) + await asyncpg_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) + + count = await asyncpg_store.delete_expired() + assert count == 2 + + assert await asyncpg_store.exists("active_session") + assert await asyncpg_store.exists("permanent_session") + assert not await asyncpg_store.exists("expired_session_1") + assert not await asyncpg_store.exists("expired_session_2") + + +async def test_store_upsert(asyncpg_store: AsyncpgStore) -> None: + """Test updating existing session (UPSERT).""" + await asyncpg_store.set("session_upsert", b"original data") + + result = await asyncpg_store.get("session_upsert") + assert result == b"original data" + + await asyncpg_store.set("session_upsert", b"updated data") + + result = await asyncpg_store.get("session_upsert") + assert result == b"updated data" + + +async def test_store_upsert_with_expiration_change(asyncpg_store: AsyncpgStore) -> None: + """Test updating session expiration.""" + await asyncpg_store.set("session_exp", b"data", expires_in=60) + + expires_in = await asyncpg_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await asyncpg_store.set("session_exp", b"data", expires_in=10) + + expires_in = await asyncpg_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(asyncpg_store: AsyncpgStore) -> None: + """Test renewing session expiration on get.""" + await asyncpg_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await asyncpg_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await asyncpg_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await asyncpg_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(asyncpg_store: AsyncpgStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await asyncpg_store.set("large_session", large_data) + + result = await asyncpg_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(asyncpg_store: AsyncpgStore) -> None: + """Test delete_all removes all sessions.""" + await asyncpg_store.set("session1", b"data1") + await asyncpg_store.set("session2", b"data2") + await asyncpg_store.set("session3", b"data3") + + assert await asyncpg_store.exists("session1") + assert await asyncpg_store.exists("session2") + assert await asyncpg_store.exists("session3") + + await asyncpg_store.delete_all() + + assert not await asyncpg_store.exists("session1") + assert not await asyncpg_store.exists("session2") + assert not await asyncpg_store.exists("session3") + + +async def test_store_exists(asyncpg_store: AsyncpgStore) -> None: + """Test exists method.""" + assert not await asyncpg_store.exists("test_session") + + await asyncpg_store.set("test_session", b"data") + + assert await asyncpg_store.exists("test_session") + + +async def test_store_context_manager(asyncpg_store: AsyncpgStore) -> None: + """Test store can be used as async context manager.""" + async with asyncpg_store: + await asyncpg_store.set("ctx_session", b"data") + + result = await asyncpg_store.get("ctx_session") + assert result == 
b"data" diff --git a/tests/integration/test_adapters/test_asyncpg/test_parameter_styles.py b/tests/integration/test_adapters/test_asyncpg/test_parameter_styles.py index 83f727d2..e2ded4ae 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_parameter_styles.py +++ b/tests/integration/test_adapters/test_asyncpg/test_parameter_styles.py @@ -56,7 +56,6 @@ async def asyncpg_parameters_session(postgres_service: PostgresService) -> "Asyn await config.close_pool() -@pytest.mark.asyncio @pytest.mark.parametrize("parameters,expected_count", [(("test1",), 1), (["test1"], 1)]) async def test_asyncpg_numeric_parameter_types( asyncpg_parameters_session: AsyncpgDriver, parameters: Any, expected_count: int @@ -71,7 +70,6 @@ async def test_asyncpg_numeric_parameter_types( assert result[0]["name"] == "test1" -@pytest.mark.asyncio async def test_asyncpg_numeric_parameter_style(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test PostgreSQL numeric parameter style with AsyncPG.""" result = await asyncpg_parameters_session.execute("SELECT * FROM test_parameters WHERE name = $1", ("test1",)) @@ -82,7 +80,6 @@ async def test_asyncpg_numeric_parameter_style(asyncpg_parameters_session: Async assert result[0]["name"] == "test1" -@pytest.mark.asyncio async def test_asyncpg_multiple_parameters_numeric(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test queries with multiple parameters using numeric style.""" result = await asyncpg_parameters_session.execute( @@ -97,7 +94,6 @@ async def test_asyncpg_multiple_parameters_numeric(asyncpg_parameters_session: A assert result[2]["value"] == 100 -@pytest.mark.asyncio async def test_asyncpg_null_parameters(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test handling of NULL parameters on AsyncPG.""" @@ -120,7 +116,6 @@ async def test_asyncpg_null_parameters(asyncpg_parameters_session: AsyncpgDriver assert null_result[0]["description"] is None -@pytest.mark.asyncio async def test_asyncpg_parameter_escaping(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test parameter escaping prevents SQL injection.""" @@ -138,7 +133,6 @@ async def test_asyncpg_parameter_escaping(asyncpg_parameters_session: AsyncpgDri assert count_result[0]["count"] >= 3 -@pytest.mark.asyncio async def test_asyncpg_parameter_with_like(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test parameters with LIKE operations.""" result = await asyncpg_parameters_session.execute("SELECT * FROM test_parameters WHERE name LIKE $1", ("test%",)) @@ -154,7 +148,6 @@ async def test_asyncpg_parameter_with_like(asyncpg_parameters_session: AsyncpgDr assert specific_result[0]["name"] == "test1" -@pytest.mark.asyncio async def test_asyncpg_parameter_with_any_array(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test parameters with PostgreSQL ANY and arrays.""" @@ -175,7 +168,6 @@ async def test_asyncpg_parameter_with_any_array(asyncpg_parameters_session: Asyn assert result[2]["name"] == "test1" -@pytest.mark.asyncio async def test_asyncpg_parameter_with_sql_object(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test parameters with SQL object.""" from sqlspec.core.statement import SQL @@ -189,7 +181,6 @@ async def test_asyncpg_parameter_with_sql_object(asyncpg_parameters_session: Asy assert all(row["value"] > 150 for row in result) -@pytest.mark.asyncio async def test_asyncpg_parameter_data_types(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test different parameter data types with AsyncPG.""" @@ -226,7 +217,6 @@ async def 
test_asyncpg_parameter_data_types(asyncpg_parameters_session: AsyncpgD assert result[0]["array_val"] == [1, 2, 3] -@pytest.mark.asyncio async def test_asyncpg_parameter_edge_cases(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test edge cases for AsyncPG parameters.""" @@ -250,7 +240,6 @@ async def test_asyncpg_parameter_edge_cases(asyncpg_parameters_session: AsyncpgD assert len(long_result[0]["description"]) == 1000 -@pytest.mark.asyncio async def test_asyncpg_parameter_with_postgresql_functions(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test parameters with PostgreSQL functions.""" @@ -276,7 +265,6 @@ async def test_asyncpg_parameter_with_postgresql_functions(asyncpg_parameters_se assert multiplied_value == expected -@pytest.mark.asyncio async def test_asyncpg_parameter_with_json(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test parameters with PostgreSQL JSON operations.""" @@ -309,7 +297,6 @@ async def test_asyncpg_parameter_with_json(asyncpg_parameters_session: AsyncpgDr assert all(row["type"] == "test" for row in result) -@pytest.mark.asyncio async def test_asyncpg_parameter_with_arrays(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test parameters with PostgreSQL array operations.""" @@ -345,7 +332,6 @@ async def test_asyncpg_parameter_with_arrays(asyncpg_parameters_session: Asyncpg assert len(length_result) == 2 -@pytest.mark.asyncio async def test_asyncpg_parameter_with_window_functions(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test parameters with PostgreSQL window functions.""" @@ -381,7 +367,6 @@ async def test_asyncpg_parameter_with_window_functions(asyncpg_parameters_sessio assert group_a_rows[1]["row_num"] == 2 -@pytest.mark.asyncio async def test_asyncpg_none_values_in_named_parameters(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test that None values in named parameters are handled correctly.""" await asyncpg_parameters_session.execute(""" @@ -444,7 +429,6 @@ async def test_asyncpg_none_values_in_named_parameters(asyncpg_parameters_sessio await asyncpg_parameters_session.execute("DROP TABLE test_none_values") -@pytest.mark.asyncio async def test_asyncpg_all_none_parameters(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test when all parameter values are None.""" await asyncpg_parameters_session.execute(""" @@ -477,7 +461,6 @@ async def test_asyncpg_all_none_parameters(asyncpg_parameters_session: AsyncpgDr await asyncpg_parameters_session.execute("DROP TABLE test_all_none") -@pytest.mark.asyncio async def test_asyncpg_jsonb_none_parameters(asyncpg_parameters_session: AsyncpgDriver) -> None: """Test JSONB column None parameter handling comprehensively.""" diff --git a/tests/integration/test_adapters/test_bigquery/conftest.py b/tests/integration/test_adapters/test_bigquery/conftest.py index 8af09474..5150ab62 100644 --- a/tests/integration/test_adapters/test_bigquery/conftest.py +++ b/tests/integration/test_adapters/test_bigquery/conftest.py @@ -37,11 +37,9 @@ def bigquery_config(bigquery_service: "BigQueryService", table_schema_prefix: st @pytest.fixture def bigquery_session(bigquery_config: BigQueryConfig) -> Generator[BigQueryDriver, Any, None]: """Create a BigQuery sync session.""" - try: - with bigquery_config.provide_session() as session: - yield session - finally: - pass + + with bigquery_config.provide_session() as session: + yield session @pytest.fixture diff --git a/tests/integration/test_adapters/test_bigquery/test_extensions/__init__.py 
b/tests/integration/test_adapters/test_bigquery/test_extensions/__init__.py new file mode 100644 index 00000000..79e69e36 --- /dev/null +++ b/tests/integration/test_adapters/test_bigquery/test_extensions/__init__.py @@ -0,0 +1 @@ +"""BigQuery extensions integration tests.""" diff --git a/tests/integration/test_adapters/test_duckdb/test_extensions/__init__.py b/tests/integration/test_adapters/test_duckdb/test_extensions/__init__.py new file mode 100644 index 00000000..9cb34069 --- /dev/null +++ b/tests/integration/test_adapters/test_duckdb/test_extensions/__init__.py @@ -0,0 +1 @@ +"""Tests for DuckDB extensions.""" diff --git a/tests/integration/test_adapters/test_duckdb/test_extensions/test_litestar/__init__.py b/tests/integration/test_adapters/test_duckdb/test_extensions/test_litestar/__init__.py new file mode 100644 index 00000000..4978f6fd --- /dev/null +++ b/tests/integration/test_adapters/test_duckdb/test_extensions/test_litestar/__init__.py @@ -0,0 +1 @@ +"""Tests for DuckDB Litestar integration.""" diff --git a/tests/integration/test_adapters/test_duckdb/test_extensions/test_litestar/test_store.py b/tests/integration/test_adapters/test_duckdb/test_extensions/test_litestar/test_store.py new file mode 100644 index 00000000..265f6bef --- /dev/null +++ b/tests/integration/test_adapters/test_duckdb/test_extensions/test_litestar/test_store.py @@ -0,0 +1,270 @@ +"""Integration tests for DuckDB sync session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta +from pathlib import Path + +import pytest + +from sqlspec.adapters.duckdb.config import DuckDBConfig +from sqlspec.adapters.duckdb.litestar.store import DuckdbStore + +pytestmark = [pytest.mark.duckdb, pytest.mark.integration] + + +@pytest.fixture +async def duckdb_store(tmp_path: Path, worker_id: str) -> AsyncGenerator[DuckdbStore, None]: + """Create DuckDB store with temporary file-based database. + + Args: + tmp_path: Pytest fixture providing unique temporary directory per test. + worker_id: Pytest-xdist fixture providing unique worker identifier. + + Note: + DuckDB in-memory databases are connection-local, not process-wide. + Since the thread-local connection pool creates separate connection + objects for each thread, we must use a file-based database to ensure + all threads share the same data. + + Worker ID ensures parallel pytest-xdist workers use separate database + files, preventing file locking conflicts. 
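+ + For illustration only, a sketch assuming the standard duckdb Python API (not code executed by this fixture): two independent in-memory connections cannot see each other's tables. + + import duckdb + + a = duckdb.connect(":memory:") + b = duckdb.connect(":memory:") + a.execute("CREATE TABLE t (x INTEGER)") + b.execute("SELECT * FROM t")  # raises duckdb.CatalogException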
+ """ + db_path = tmp_path / f"test_sessions_{worker_id}.duckdb" + try: + config = DuckDBConfig(pool_config={"database": str(db_path)}) + store = DuckdbStore(config, table_name="test_sessions") + await store.create_table() + yield store + await store.delete_all() + finally: + if db_path.exists(): + db_path.unlink() + + +async def test_store_create_table(duckdb_store: DuckdbStore) -> None: + """Test table creation.""" + assert duckdb_store.table_name == "test_sessions" + + +async def test_store_set_and_get(duckdb_store: DuckdbStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await duckdb_store.set("session_123", test_data) + + result = await duckdb_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(duckdb_store: DuckdbStore) -> None: + """Test getting a non-existent session returns None.""" + result = await duckdb_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(duckdb_store: DuckdbStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await duckdb_store.set("session_str", "string data") + + result = await duckdb_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(duckdb_store: DuckdbStore) -> None: + """Test delete operation.""" + await duckdb_store.set("session_to_delete", b"data") + + assert await duckdb_store.exists("session_to_delete") + + await duckdb_store.delete("session_to_delete") + + assert not await duckdb_store.exists("session_to_delete") + assert await duckdb_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(duckdb_store: DuckdbStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await duckdb_store.delete("nonexistent") + + +async def test_store_expiration_with_int(duckdb_store: DuckdbStore) -> None: + """Test session expiration with integer seconds.""" + await duckdb_store.set("expiring_session", b"data", expires_in=1) + + assert await duckdb_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await duckdb_store.get("expiring_session") + assert result is None + assert not await duckdb_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(duckdb_store: DuckdbStore) -> None: + """Test session expiration with timedelta.""" + await duckdb_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await duckdb_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await duckdb_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(duckdb_store: DuckdbStore) -> None: + """Test session without expiration persists.""" + await duckdb_store.set("permanent_session", b"data") + + expires_in = await duckdb_store.expires_in("permanent_session") + assert expires_in is None + + assert await duckdb_store.exists("permanent_session") + + +async def test_store_expires_in(duckdb_store: DuckdbStore) -> None: + """Test expires_in returns correct time.""" + await duckdb_store.set("timed_session", b"data", expires_in=10) + + expires_in = await duckdb_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(duckdb_store: DuckdbStore) -> None: + """Test expires_in returns 0 for expired session.""" + await duckdb_store.set("expired_session", b"data", expires_in=1) + + await asyncio.sleep(1.1) + + expires_in = await 
duckdb_store.expires_in("expired_session") + assert expires_in == 0 + + +async def test_store_cleanup(duckdb_store: DuckdbStore) -> None: + """Test delete_expired removes only expired sessions.""" + await duckdb_store.set("active_session", b"data", expires_in=60) + await duckdb_store.set("expired_session_1", b"data", expires_in=1) + await duckdb_store.set("expired_session_2", b"data", expires_in=1) + await duckdb_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) + + count = await duckdb_store.delete_expired() + assert count == 2 + + assert await duckdb_store.exists("active_session") + assert await duckdb_store.exists("permanent_session") + assert not await duckdb_store.exists("expired_session_1") + assert not await duckdb_store.exists("expired_session_2") + + +async def test_store_upsert(duckdb_store: DuckdbStore) -> None: + """Test updating existing session (UPSERT).""" + await duckdb_store.set("session_upsert", b"original data") + + result = await duckdb_store.get("session_upsert") + assert result == b"original data" + + await duckdb_store.set("session_upsert", b"updated data") + + result = await duckdb_store.get("session_upsert") + assert result == b"updated data" + + +async def test_store_upsert_with_expiration_change(duckdb_store: DuckdbStore) -> None: + """Test updating session expiration.""" + await duckdb_store.set("session_exp", b"data", expires_in=60) + + expires_in = await duckdb_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await duckdb_store.set("session_exp", b"data", expires_in=10) + + expires_in = await duckdb_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(duckdb_store: DuckdbStore) -> None: + """Test renewing session expiration on get.""" + await duckdb_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await duckdb_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await duckdb_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await duckdb_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(duckdb_store: DuckdbStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await duckdb_store.set("large_session", large_data) + + result = await duckdb_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(duckdb_store: DuckdbStore) -> None: + """Test delete_all removes all sessions.""" + await duckdb_store.set("session1", b"data1") + await duckdb_store.set("session2", b"data2") + await duckdb_store.set("session3", b"data3") + + assert await duckdb_store.exists("session1") + assert await duckdb_store.exists("session2") + assert await duckdb_store.exists("session3") + + await duckdb_store.delete_all() + + assert not await duckdb_store.exists("session1") + assert not await duckdb_store.exists("session2") + assert not await duckdb_store.exists("session3") + + +async def test_store_exists(duckdb_store: DuckdbStore) -> None: + """Test exists method.""" + assert not await duckdb_store.exists("test_session") + + await duckdb_store.set("test_session", b"data") + + assert await duckdb_store.exists("test_session") + + +async def test_store_context_manager(duckdb_store: 
DuckdbStore) -> None: + """Test store can be used as async context manager.""" + async with duckdb_store: + await duckdb_store.set("ctx_session", b"data") + + result = await duckdb_store.get("ctx_session") + assert result == b"data" + + +async def test_sync_to_thread_concurrency(duckdb_store: DuckdbStore) -> None: + """Test concurrent access via sync_to_thread wrapper. + + DuckDB has write serialization, so we test sequential writes + followed by concurrent reads which is the typical session store pattern. + """ + for i in range(10): + await duckdb_store.set(f"session_{i}", f"data_{i}".encode()) + + async def read_session(session_id: int) -> "bytes | None": + return await duckdb_store.get(f"session_{session_id}") + + results = await asyncio.gather(*[read_session(i) for i in range(10)]) + + for i, result in enumerate(results): + assert result == f"data_{i}".encode() diff --git a/tests/integration/test_adapters/test_oracledb/test_execute_many.py b/tests/integration/test_adapters/test_oracledb/test_execute_many.py index 13326ed6..78ba974a 100644 --- a/tests/integration/test_adapters/test_oracledb/test_execute_many.py +++ b/tests/integration/test_adapters/test_oracledb/test_execute_many.py @@ -68,7 +68,6 @@ def test_sync_execute_many_insert_batch(oracle_sync_session: OracleSyncDriver) - ) -@pytest.mark.asyncio(loop_scope="function") async def test_async_execute_many_update_batch(oracle_async_session: OracleAsyncDriver) -> None: """Test execute_many with batch UPDATE operations.""" @@ -192,7 +191,6 @@ def test_sync_execute_many_with_named_parameters(oracle_sync_session: OracleSync ) -@pytest.mark.asyncio(loop_scope="function") async def test_async_execute_many_with_sequences(oracle_async_session: OracleAsyncDriver) -> None: """Test execute_many with Oracle sequences for auto-incrementing IDs.""" diff --git a/tests/integration/test_adapters/test_oracledb/test_extensions/__init__.py b/tests/integration/test_adapters/test_oracledb/test_extensions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/__init__.py b/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_async.py b/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_async.py new file mode 100644 index 00000000..a1829a1b --- /dev/null +++ b/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_async.py @@ -0,0 +1,247 @@ +"""Integration tests for Oracle session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest +from pytest_databases.docker.oracle import OracleService + +from sqlspec.adapters.oracledb.config import OracleAsyncConfig +from sqlspec.adapters.oracledb.litestar.store import OracleAsyncStore + +pytestmark = pytest.mark.xdist_group("oracle") + + +@pytest.fixture +async def oracle_store(oracle_23ai_service: OracleService) -> "AsyncGenerator[OracleAsyncStore, None]": + """Create Oracle store with test database.""" + config = OracleAsyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + "min": 1, + "max": 5, + } + ) + store = OracleAsyncStore(config, 
table_name="test_sessions") + try: + await store.create_table() + yield store + await store.delete_all() + finally: + if config.pool_instance: + await config.close_pool() + + +async def test_store_create_table(oracle_store: OracleAsyncStore) -> None: + """Test table creation.""" + assert oracle_store.table_name == "test_sessions" + + +async def test_store_set_and_get(oracle_store: OracleAsyncStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await oracle_store.set("session_123", test_data) + + result = await oracle_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(oracle_store: OracleAsyncStore) -> None: + """Test getting a non-existent session returns None.""" + result = await oracle_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(oracle_store: OracleAsyncStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await oracle_store.set("session_str", "string data") + + result = await oracle_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(oracle_store: OracleAsyncStore) -> None: + """Test delete operation.""" + await oracle_store.set("session_to_delete", b"data") + + assert await oracle_store.exists("session_to_delete") + + await oracle_store.delete("session_to_delete") + + assert not await oracle_store.exists("session_to_delete") + assert await oracle_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(oracle_store: OracleAsyncStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await oracle_store.delete("nonexistent") + + +async def test_store_expiration_with_int(oracle_store: OracleAsyncStore) -> None: + """Test session expiration with integer seconds.""" + await oracle_store.set("expiring_session", b"data", expires_in=1) + + assert await oracle_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await oracle_store.get("expiring_session") + assert result is None + assert not await oracle_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(oracle_store: OracleAsyncStore) -> None: + """Test session expiration with timedelta.""" + await oracle_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await oracle_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await oracle_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(oracle_store: OracleAsyncStore) -> None: + """Test session without expiration persists.""" + await oracle_store.set("permanent_session", b"data") + + expires_in = await oracle_store.expires_in("permanent_session") + assert expires_in is None + + assert await oracle_store.exists("permanent_session") + + +async def test_store_expires_in(oracle_store: OracleAsyncStore) -> None: + """Test expires_in returns correct time.""" + await oracle_store.set("timed_session", b"data", expires_in=10) + + expires_in = await oracle_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(oracle_store: OracleAsyncStore) -> None: + """Test expires_in returns 0 for expired session.""" + await oracle_store.set("expired_session", b"data", expires_in=1) + + await asyncio.sleep(1.1) + + expires_in = await oracle_store.expires_in("expired_session") + assert expires_in == 0 + + +async def 
test_store_cleanup(oracle_store: OracleAsyncStore) -> None: + """Test delete_expired removes only expired sessions.""" + await oracle_store.set("active_session", b"data", expires_in=60) + await oracle_store.set("expired_session_1", b"data", expires_in=1) + await oracle_store.set("expired_session_2", b"data", expires_in=1) + await oracle_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) + + count = await oracle_store.delete_expired() + assert count == 2 + + assert await oracle_store.exists("active_session") + assert await oracle_store.exists("permanent_session") + assert not await oracle_store.exists("expired_session_1") + assert not await oracle_store.exists("expired_session_2") + + +async def test_store_upsert(oracle_store: OracleAsyncStore) -> None: + """Test updating existing session (UPSERT).""" + await oracle_store.set("session_upsert", b"original data") + + result = await oracle_store.get("session_upsert") + assert result == b"original data" + + await oracle_store.set("session_upsert", b"updated data") + + result = await oracle_store.get("session_upsert") + assert result == b"updated data" + + +async def test_store_upsert_with_expiration_change(oracle_store: OracleAsyncStore) -> None: + """Test updating session expiration.""" + await oracle_store.set("session_exp", b"data", expires_in=60) + + expires_in = await oracle_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await oracle_store.set("session_exp", b"data", expires_in=10) + + expires_in = await oracle_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(oracle_store: OracleAsyncStore) -> None: + """Test renewing session expiration on get.""" + await oracle_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await oracle_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await oracle_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await oracle_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(oracle_store: OracleAsyncStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await oracle_store.set("large_session", large_data) + + result = await oracle_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(oracle_store: OracleAsyncStore) -> None: + """Test delete_all removes all sessions.""" + await oracle_store.set("session1", b"data1") + await oracle_store.set("session2", b"data2") + await oracle_store.set("session3", b"data3") + + assert await oracle_store.exists("session1") + assert await oracle_store.exists("session2") + assert await oracle_store.exists("session3") + + await oracle_store.delete_all() + + assert not await oracle_store.exists("session1") + assert not await oracle_store.exists("session2") + assert not await oracle_store.exists("session3") + + +async def test_store_exists(oracle_store: OracleAsyncStore) -> None: + """Test exists method.""" + assert not await oracle_store.exists("test_session") + + await oracle_store.set("test_session", b"data") + + assert await oracle_store.exists("test_session") + + +async def test_store_context_manager(oracle_store: OracleAsyncStore) -> None: + """Test store can be 
used as async context manager.""" + async with oracle_store: + await oracle_store.set("ctx_session", b"data") + + result = await oracle_store.get("ctx_session") + assert result == b"data" diff --git a/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_sync.py b/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_sync.py new file mode 100644 index 00000000..a6bbcfb6 --- /dev/null +++ b/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_sync.py @@ -0,0 +1,247 @@ +"""Integration tests for Oracle sync session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest +from pytest_databases.docker.oracle import OracleService + +from sqlspec.adapters.oracledb.config import OracleSyncConfig +from sqlspec.adapters.oracledb.litestar.store import OracleSyncStore + +pytestmark = pytest.mark.xdist_group("oracle") + + +@pytest.fixture +async def oracle_sync_store(oracle_23ai_service: OracleService) -> AsyncGenerator[OracleSyncStore, None]: + """Create Oracle sync store with test database.""" + config = OracleSyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + "min": 1, + "max": 5, + } + ) + store: OracleSyncStore = OracleSyncStore(config, table_name="test_sessions_sync") + try: + await store.create_table() + yield store + await store.delete_all() + finally: + if config.pool_instance: + config.close_pool() + + +async def test_store_create_table(oracle_sync_store: OracleSyncStore) -> None: + """Test table creation.""" + assert oracle_sync_store.table_name == "test_sessions_sync" + + +async def test_store_set_and_get(oracle_sync_store: OracleSyncStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await oracle_sync_store.set("session_123", test_data) + + result = await oracle_sync_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(oracle_sync_store: OracleSyncStore) -> None: + """Test getting a non-existent session returns None.""" + result = await oracle_sync_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(oracle_sync_store: OracleSyncStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await oracle_sync_store.set("session_str", "string data") + + result = await oracle_sync_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(oracle_sync_store: OracleSyncStore) -> None: + """Test delete operation.""" + await oracle_sync_store.set("session_to_delete", b"data") + + assert await oracle_sync_store.exists("session_to_delete") + + await oracle_sync_store.delete("session_to_delete") + + assert not await oracle_sync_store.exists("session_to_delete") + assert await oracle_sync_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(oracle_sync_store: OracleSyncStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await oracle_sync_store.delete("nonexistent") + + +async def test_store_expiration_with_int(oracle_sync_store: OracleSyncStore) -> None: + """Test session expiration with integer seconds.""" + await oracle_sync_store.set("expiring_session", b"data", expires_in=1) + + assert await 
oracle_sync_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await oracle_sync_store.get("expiring_session") + assert result is None + assert not await oracle_sync_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(oracle_sync_store: OracleSyncStore) -> None: + """Test session expiration with timedelta.""" + await oracle_sync_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await oracle_sync_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await oracle_sync_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(oracle_sync_store: OracleSyncStore) -> None: + """Test session without expiration persists.""" + await oracle_sync_store.set("permanent_session", b"data") + + expires_in = await oracle_sync_store.expires_in("permanent_session") + assert expires_in is None + + assert await oracle_sync_store.exists("permanent_session") + + +async def test_store_expires_in(oracle_sync_store: OracleSyncStore) -> None: + """Test expires_in returns correct time.""" + await oracle_sync_store.set("timed_session", b"data", expires_in=10) + + expires_in = await oracle_sync_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(oracle_sync_store: OracleSyncStore) -> None: + """Test expires_in returns 0 for expired session.""" + await oracle_sync_store.set("expired_session", b"data", expires_in=1) + + await asyncio.sleep(1.1) + + expires_in = await oracle_sync_store.expires_in("expired_session") + assert expires_in == 0 + + +async def test_store_cleanup(oracle_sync_store: OracleSyncStore) -> None: + """Test delete_expired removes only expired sessions.""" + await oracle_sync_store.set("active_session", b"data", expires_in=60) + await oracle_sync_store.set("expired_session_1", b"data", expires_in=1) + await oracle_sync_store.set("expired_session_2", b"data", expires_in=1) + await oracle_sync_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) + + count = await oracle_sync_store.delete_expired() + assert count == 2 + + assert await oracle_sync_store.exists("active_session") + assert await oracle_sync_store.exists("permanent_session") + assert not await oracle_sync_store.exists("expired_session_1") + assert not await oracle_sync_store.exists("expired_session_2") + + +async def test_store_upsert(oracle_sync_store: OracleSyncStore) -> None: + """Test updating existing session (UPSERT).""" + await oracle_sync_store.set("session_upsert", b"original data") + + result = await oracle_sync_store.get("session_upsert") + assert result == b"original data" + + await oracle_sync_store.set("session_upsert", b"updated data") + + result = await oracle_sync_store.get("session_upsert") + assert result == b"updated data" + + +async def test_store_upsert_with_expiration_change(oracle_sync_store: OracleSyncStore) -> None: + """Test updating session expiration.""" + await oracle_sync_store.set("session_exp", b"data", expires_in=60) + + expires_in = await oracle_sync_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await oracle_sync_store.set("session_exp", b"data", expires_in=10) + + expires_in = await oracle_sync_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(oracle_sync_store: OracleSyncStore) -> None: + """Test renewing session expiration on get.""" + 
await oracle_sync_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await oracle_sync_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await oracle_sync_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await oracle_sync_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(oracle_sync_store: OracleSyncStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await oracle_sync_store.set("large_session", large_data) + + result = await oracle_sync_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(oracle_sync_store: OracleSyncStore) -> None: + """Test delete_all removes all sessions.""" + await oracle_sync_store.set("session1", b"data1") + await oracle_sync_store.set("session2", b"data2") + await oracle_sync_store.set("session3", b"data3") + + assert await oracle_sync_store.exists("session1") + assert await oracle_sync_store.exists("session2") + assert await oracle_sync_store.exists("session3") + + await oracle_sync_store.delete_all() + + assert not await oracle_sync_store.exists("session1") + assert not await oracle_sync_store.exists("session2") + assert not await oracle_sync_store.exists("session3") + + +async def test_store_exists(oracle_sync_store: OracleSyncStore) -> None: + """Test exists method.""" + assert not await oracle_sync_store.exists("test_session") + + await oracle_sync_store.set("test_session", b"data") + + assert await oracle_sync_store.exists("test_session") + + +async def test_store_context_manager(oracle_sync_store: OracleSyncStore) -> None: + """Test store can be used as async context manager.""" + async with oracle_sync_store: + await oracle_sync_store.set("ctx_session", b"data") + + result = await oracle_sync_store.get("ctx_session") + assert result == b"data" diff --git a/tests/integration/test_adapters/test_oracledb/test_oracle_features.py b/tests/integration/test_adapters/test_oracledb/test_oracle_features.py index c5442a64..e1f95a63 100644 --- a/tests/integration/test_adapters/test_oracledb/test_oracle_features.py +++ b/tests/integration/test_adapters/test_oracledb/test_oracle_features.py @@ -73,7 +73,6 @@ def test_sync_plsql_block_execution(oracle_sync_session: OracleSyncDriver) -> No ) -@pytest.mark.asyncio(loop_scope="function") async def test_async_plsql_procedure_execution(oracle_async_session: OracleAsyncDriver) -> None: """Test creation and execution of PL/SQL stored procedures.""" @@ -197,7 +196,6 @@ def test_sync_oracle_data_types(oracle_sync_session: OracleSyncDriver) -> None: ) -@pytest.mark.asyncio(loop_scope="function") async def test_async_oracle_analytic_functions(oracle_async_session: OracleAsyncDriver) -> None: """Test Oracle's analytic/window functions.""" @@ -298,7 +296,6 @@ def test_oracle_ddl_script_parsing(oracle_sync_session: OracleSyncDriver) -> Non assert "CREATE SEQUENCE" in sql_output -@pytest.mark.asyncio(loop_scope="function") async def test_async_oracle_exception_handling(oracle_async_session: OracleAsyncDriver) -> None: """Test Oracle-specific exception handling in PL/SQL.""" diff --git a/tests/integration/test_adapters/test_oracledb/test_parameter_styles.py b/tests/integration/test_adapters/test_oracledb/test_parameter_styles.py index 8e122372..2693c38a 
100644 --- a/tests/integration/test_adapters/test_oracledb/test_parameter_styles.py +++ b/tests/integration/test_adapters/test_oracledb/test_parameter_styles.py @@ -57,7 +57,6 @@ def test_sync_oracle_parameter_styles( ), ], ) -@pytest.mark.asyncio(loop_scope="function") async def test_async_oracle_parameter_styles( oracle_async_session: OracleAsyncDriver, sql: str, params: OracleParamData, expected_rows: list[dict[str, Any]] ) -> None: @@ -112,7 +111,6 @@ def test_sync_oracle_insert_with_named_params(oracle_sync_session: OracleSyncDri ) -@pytest.mark.asyncio(loop_scope="function") async def test_async_oracle_update_with_mixed_params(oracle_async_session: OracleAsyncDriver) -> None: """Test UPDATE operations using mixed parameter styles.""" @@ -203,7 +201,6 @@ def test_sync_oracle_in_clause_with_params(oracle_sync_session: OracleSyncDriver ) -@pytest.mark.asyncio(loop_scope="function") async def test_async_oracle_null_parameter_handling(oracle_async_session: OracleAsyncDriver) -> None: """Test handling of NULL parameters in Oracle.""" @@ -479,7 +476,6 @@ def test_sync_oracle_none_parameters_with_execute_many(oracle_sync_session: Orac ) -@pytest.mark.asyncio(loop_scope="function") async def test_async_oracle_lob_none_parameter_handling(oracle_async_session: OracleAsyncDriver) -> None: """Test Oracle LOB (CLOB/RAW) None parameter handling in async operations.""" @@ -581,7 +577,6 @@ async def test_async_oracle_lob_none_parameter_handling(oracle_async_session: Or ) -@pytest.mark.asyncio(loop_scope="function") async def test_async_oracle_json_none_parameter_handling(oracle_async_session: OracleAsyncDriver) -> None: """Test Oracle JSON column None parameter handling (Oracle 21+ and constraint-based).""" diff --git a/tests/integration/test_adapters/test_psqlpy/test_connection.py b/tests/integration/test_adapters/test_psqlpy/test_connection.py index 588a217f..3b2bd86e 100644 --- a/tests/integration/test_adapters/test_psqlpy/test_connection.py +++ b/tests/integration/test_adapters/test_psqlpy/test_connection.py @@ -15,7 +15,6 @@ pytestmark = pytest.mark.xdist_group("postgres") -@pytest.mark.asyncio async def test_connect_via_pool(psqlpy_config: PsqlpyConfig) -> None: """Test establishing a connection via the pool.""" pool = await psqlpy_config.create_pool() @@ -29,7 +28,6 @@ async def test_connect_via_pool(psqlpy_config: PsqlpyConfig) -> None: assert rows[0]["?column?"] == 1 -@pytest.mark.asyncio async def test_connect_direct(psqlpy_config: PsqlpyConfig) -> None: """Test establishing a connection via the provide_connection context manager.""" @@ -42,7 +40,6 @@ async def test_connect_direct(psqlpy_config: PsqlpyConfig) -> None: assert rows[0]["?column?"] == 1 -@pytest.mark.asyncio async def test_provide_session_context_manager(psqlpy_config: PsqlpyConfig) -> None: """Test the provide_session context manager.""" async with psqlpy_config.provide_session() as driver: @@ -58,7 +55,6 @@ async def test_provide_session_context_manager(psqlpy_config: PsqlpyConfig) -> N assert val == "test" -@pytest.mark.asyncio async def test_connection_error_handling(psqlpy_config: PsqlpyConfig) -> None: """Test connection error handling.""" async with psqlpy_config.provide_session() as driver: @@ -71,7 +67,6 @@ async def test_connection_error_handling(psqlpy_config: PsqlpyConfig) -> None: assert result.data[0]["status"] == "still_working" -@pytest.mark.asyncio async def test_connection_with_core_round_3(psqlpy_config: PsqlpyConfig) -> None: """Test connection integration.""" from sqlspec.core.statement import SQL @@ -86,7 
+81,6 @@ async def test_connection_with_core_round_3(psqlpy_config: PsqlpyConfig) -> None assert result.data[0]["test_value"] == "core_test" -@pytest.mark.asyncio async def test_multiple_connections_sequential(psqlpy_config: PsqlpyConfig) -> None: """Test multiple sequential connections.""" @@ -103,7 +97,6 @@ async def test_multiple_connections_sequential(psqlpy_config: PsqlpyConfig) -> N assert result2.data[0]["conn_id"] == "connection2" -@pytest.mark.asyncio async def test_connection_concurrent_access(psqlpy_config: PsqlpyConfig) -> None: """Test concurrent connection access.""" import asyncio diff --git a/tests/integration/test_adapters/test_psqlpy/test_driver.py b/tests/integration/test_adapters/test_psqlpy/test_driver.py index d26ee1c7..ac5f002b 100644 --- a/tests/integration/test_adapters/test_psqlpy/test_driver.py +++ b/tests/integration/test_adapters/test_psqlpy/test_driver.py @@ -26,7 +26,6 @@ pytest.param({"name": "test_name"}, "dict_binds", id="dict_binds"), ], ) -@pytest.mark.asyncio async def test_insert_returning_param_styles(psqlpy_session: PsqlpyDriver, parameters: Any, style: ParamStyle) -> None: """Test insert returning with different parameter styles.""" if style == "tuple_binds": diff --git a/tests/integration/test_adapters/test_psqlpy/test_extensions/__init__.py b/tests/integration/test_adapters/test_psqlpy/test_extensions/__init__.py new file mode 100644 index 00000000..87521e21 --- /dev/null +++ b/tests/integration/test_adapters/test_psqlpy/test_extensions/__init__.py @@ -0,0 +1 @@ +"""Tests for psqlpy extensions.""" diff --git a/tests/integration/test_adapters/test_psqlpy/test_extensions/test_litestar/__init__.py b/tests/integration/test_adapters/test_psqlpy/test_extensions/test_litestar/__init__.py new file mode 100644 index 00000000..56cecf4e --- /dev/null +++ b/tests/integration/test_adapters/test_psqlpy/test_extensions/test_litestar/__init__.py @@ -0,0 +1 @@ +"""Tests for psqlpy litestar integration.""" diff --git a/tests/integration/test_adapters/test_psqlpy/test_extensions/test_litestar/test_store.py b/tests/integration/test_adapters/test_psqlpy/test_extensions/test_litestar/test_store.py new file mode 100644 index 00000000..4f35055a --- /dev/null +++ b/tests/integration/test_adapters/test_psqlpy/test_extensions/test_litestar/test_store.py @@ -0,0 +1,245 @@ +"""Integration tests for Psqlpy session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest +from pytest_databases.docker.postgres import PostgresService + +from sqlspec.adapters.psqlpy.config import PsqlpyConfig +from sqlspec.adapters.psqlpy.litestar.store import PsqlpyStore + +pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.psqlpy, pytest.mark.integration] + + +@pytest.fixture +async def psqlpy_store(postgres_service: PostgresService) -> "AsyncGenerator[PsqlpyStore, None]": + """Create Psqlpy store with test database.""" + config = PsqlpyConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "username": postgres_service.user, + "password": postgres_service.password, + "db_name": postgres_service.database, + } + ) + store = PsqlpyStore(config, table_name="test_psqlpy_sessions") + try: + await store.create_table() + yield store + await store.delete_all() + finally: + if config.pool_instance: + await config.close_pool() + + +async def test_store_create_table(psqlpy_store: PsqlpyStore) -> None: + """Test table creation.""" + assert psqlpy_store.table_name == "test_psqlpy_sessions" + + 
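+# Note: the tests below exercise the full Litestar Store protocol surface as
+# implemented by PsqlpyStore: set/get round-trips (bytes and str values),
+# delete, existence checks, expiration via int seconds or timedelta,
+# renew-on-get, and bulk cleanup. All of them use the psqlpy_store fixture
+# defined above.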
+async def test_store_set_and_get(psqlpy_store: PsqlpyStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await psqlpy_store.set("session_123", test_data) + + result = await psqlpy_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(psqlpy_store: PsqlpyStore) -> None: + """Test getting a non-existent session returns None.""" + result = await psqlpy_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(psqlpy_store: PsqlpyStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await psqlpy_store.set("session_str", "string data") + + result = await psqlpy_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(psqlpy_store: PsqlpyStore) -> None: + """Test delete operation.""" + await psqlpy_store.set("session_to_delete", b"data") + + assert await psqlpy_store.exists("session_to_delete") + + await psqlpy_store.delete("session_to_delete") + + assert not await psqlpy_store.exists("session_to_delete") + assert await psqlpy_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(psqlpy_store: PsqlpyStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await psqlpy_store.delete("nonexistent") + + +async def test_store_expiration_with_int(psqlpy_store: PsqlpyStore) -> None: + """Test session expiration with integer seconds.""" + await psqlpy_store.set("expiring_session", b"data", expires_in=1) + + assert await psqlpy_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await psqlpy_store.get("expiring_session") + assert result is None + assert not await psqlpy_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(psqlpy_store: PsqlpyStore) -> None: + """Test session expiration with timedelta.""" + await psqlpy_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await psqlpy_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await psqlpy_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(psqlpy_store: PsqlpyStore) -> None: + """Test session without expiration persists.""" + await psqlpy_store.set("permanent_session", b"data") + + expires_in = await psqlpy_store.expires_in("permanent_session") + assert expires_in is None + + assert await psqlpy_store.exists("permanent_session") + + +async def test_store_expires_in(psqlpy_store: PsqlpyStore) -> None: + """Test expires_in returns correct time.""" + await psqlpy_store.set("timed_session", b"data", expires_in=10) + + expires_in = await psqlpy_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(psqlpy_store: PsqlpyStore) -> None: + """Test expires_in returns 0 for expired session.""" + await psqlpy_store.set("expired_session", b"data", expires_in=1) + + await asyncio.sleep(1.1) + + expires_in = await psqlpy_store.expires_in("expired_session") + assert expires_in == 0 + + +async def test_store_cleanup(psqlpy_store: PsqlpyStore) -> None: + """Test delete_expired removes only expired sessions.""" + await psqlpy_store.set("active_session", b"data", expires_in=60) + await psqlpy_store.set("expired_session_1", b"data", expires_in=1) + await psqlpy_store.set("expired_session_2", b"data", expires_in=1) + await psqlpy_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) 
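+    # The 1.1s sleep lets the two 1-second sessions lapse while the 60-second
+    # and permanent sessions stay live; delete_expired() returns the number of
+    # rows it removed, as asserted below.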
+ + count = await psqlpy_store.delete_expired() + assert count == 2 + + assert await psqlpy_store.exists("active_session") + assert await psqlpy_store.exists("permanent_session") + assert not await psqlpy_store.exists("expired_session_1") + assert not await psqlpy_store.exists("expired_session_2") + + +async def test_store_upsert(psqlpy_store: PsqlpyStore) -> None: + """Test updating existing session (UPSERT).""" + await psqlpy_store.set("session_upsert", b"original data") + + result = await psqlpy_store.get("session_upsert") + assert result == b"original data" + + await psqlpy_store.set("session_upsert", b"updated data") + + result = await psqlpy_store.get("session_upsert") + assert result == b"updated data" + + +async def test_store_upsert_with_expiration_change(psqlpy_store: PsqlpyStore) -> None: + """Test updating session expiration.""" + await psqlpy_store.set("session_exp", b"data", expires_in=60) + + expires_in = await psqlpy_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await psqlpy_store.set("session_exp", b"data", expires_in=10) + + expires_in = await psqlpy_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(psqlpy_store: PsqlpyStore) -> None: + """Test renewing session expiration on get.""" + await psqlpy_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await psqlpy_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await psqlpy_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await psqlpy_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(psqlpy_store: PsqlpyStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await psqlpy_store.set("large_session", large_data) + + result = await psqlpy_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(psqlpy_store: PsqlpyStore) -> None: + """Test delete_all removes all sessions.""" + await psqlpy_store.set("session1", b"data1") + await psqlpy_store.set("session2", b"data2") + await psqlpy_store.set("session3", b"data3") + + assert await psqlpy_store.exists("session1") + assert await psqlpy_store.exists("session2") + assert await psqlpy_store.exists("session3") + + await psqlpy_store.delete_all() + + assert not await psqlpy_store.exists("session1") + assert not await psqlpy_store.exists("session2") + assert not await psqlpy_store.exists("session3") + + +async def test_store_exists(psqlpy_store: PsqlpyStore) -> None: + """Test exists method.""" + assert not await psqlpy_store.exists("test_session") + + await psqlpy_store.set("test_session", b"data") + + assert await psqlpy_store.exists("test_session") + + +async def test_store_context_manager(psqlpy_store: PsqlpyStore) -> None: + """Test store can be used as async context manager.""" + async with psqlpy_store: + await psqlpy_store.set("ctx_session", b"data") + + result = await psqlpy_store.get("ctx_session") + assert result == b"data" diff --git a/tests/integration/test_adapters/test_psycopg/test_async_copy.py b/tests/integration/test_adapters/test_psycopg/test_async_copy.py index 3dc55eef..0e6de6d7 100644 --- a/tests/integration/test_adapters/test_psycopg/test_async_copy.py +++ 
b/tests/integration/test_adapters/test_psycopg/test_async_copy.py @@ -45,7 +45,6 @@ async def psycopg_async_session(postgres_service: PostgresService) -> AsyncGener await config.close_pool() -@pytest.mark.asyncio async def test_psycopg_async_copy_operations_positional(psycopg_async_session: PsycopgAsyncDriver) -> None: """Test PostgreSQL COPY operations with async psycopg driver using positional parameters.""" @@ -73,7 +72,6 @@ async def test_psycopg_async_copy_operations_positional(psycopg_async_session: P await psycopg_async_session.execute_script("DROP TABLE copy_test_async") -@pytest.mark.asyncio async def test_psycopg_async_copy_operations_keyword(psycopg_async_session: PsycopgAsyncDriver) -> None: """Test PostgreSQL COPY operations with async psycopg driver using keyword parameters.""" @@ -101,7 +99,6 @@ async def test_psycopg_async_copy_operations_keyword(psycopg_async_session: Psyc await psycopg_async_session.execute_script("DROP TABLE copy_test_async_kw") -@pytest.mark.asyncio async def test_psycopg_async_copy_csv_format_positional(psycopg_async_session: PsycopgAsyncDriver) -> None: """Test PostgreSQL COPY operations with CSV format using async driver and positional parameters.""" @@ -128,7 +125,6 @@ async def test_psycopg_async_copy_csv_format_positional(psycopg_async_session: P await psycopg_async_session.execute_script("DROP TABLE copy_csv_async_pos") -@pytest.mark.asyncio async def test_psycopg_async_copy_csv_format_keyword(psycopg_async_session: PsycopgAsyncDriver) -> None: """Test PostgreSQL COPY operations with CSV format using async driver and keyword parameters.""" diff --git a/tests/integration/test_adapters/test_psycopg/test_extensions/__init__.py b/tests/integration/test_adapters/test_psycopg/test_extensions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_adapters/test_psycopg/test_extensions/test_litestar/__init__.py b/tests/integration/test_adapters/test_psycopg/test_extensions/test_litestar/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/test_adapters/test_psycopg/test_extensions/test_litestar/test_store_async.py b/tests/integration/test_adapters/test_psycopg/test_extensions/test_litestar/test_store_async.py new file mode 100644 index 00000000..ee2bda34 --- /dev/null +++ b/tests/integration/test_adapters/test_psycopg/test_extensions/test_litestar/test_store_async.py @@ -0,0 +1,248 @@ +"""Integration tests for Psycopg async session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest +from pytest_databases.docker.postgres import PostgresService + +from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig +from sqlspec.adapters.psycopg.litestar.store import PsycopgAsyncStore + +pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.psycopg, pytest.mark.integration] + + +@pytest.fixture +async def psycopg_async_store(postgres_service: PostgresService) -> "AsyncGenerator[PsycopgAsyncStore, None]": + """Create Psycopg async store with test database.""" + config = PsycopgAsyncConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "dbname": postgres_service.database, + } + ) + store = PsycopgAsyncStore(config, table_name="test_psycopg_async_sessions") + try: + await store.create_table() + yield store + try: + await store.delete_all() + except Exception: + pass + finally: + if config.pool_instance: + 
await config.close_pool() + + +async def test_store_create_table(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test table creation.""" + assert psycopg_async_store.table_name == "test_psycopg_async_sessions" + + +async def test_store_set_and_get(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await psycopg_async_store.set("session_123", test_data) + + result = await psycopg_async_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test getting a non-existent session returns None.""" + result = await psycopg_async_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await psycopg_async_store.set("session_str", "string data") + + result = await psycopg_async_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test delete operation.""" + await psycopg_async_store.set("session_to_delete", b"data") + + assert await psycopg_async_store.exists("session_to_delete") + + await psycopg_async_store.delete("session_to_delete") + + assert not await psycopg_async_store.exists("session_to_delete") + assert await psycopg_async_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await psycopg_async_store.delete("nonexistent") + + +async def test_store_expiration_with_int(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test session expiration with integer seconds.""" + await psycopg_async_store.set("expiring_session", b"data", expires_in=1) + + assert await psycopg_async_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await psycopg_async_store.get("expiring_session") + assert result is None + assert not await psycopg_async_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test session expiration with timedelta.""" + await psycopg_async_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await psycopg_async_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await psycopg_async_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test session without expiration persists.""" + await psycopg_async_store.set("permanent_session", b"data") + + expires_in = await psycopg_async_store.expires_in("permanent_session") + assert expires_in is None + + assert await psycopg_async_store.exists("permanent_session") + + +async def test_store_expires_in(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test expires_in returns correct time.""" + await psycopg_async_store.set("timed_session", b"data", expires_in=10) + + expires_in = await psycopg_async_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test expires_in returns 0 for expired session.""" + await psycopg_async_store.set("expired_session", b"data", expires_in=1) + + await 
asyncio.sleep(1.1) + + expires_in = await psycopg_async_store.expires_in("expired_session") + assert expires_in == 0 + + +async def test_store_cleanup(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test delete_expired removes only expired sessions.""" + await psycopg_async_store.set("active_session", b"data", expires_in=60) + await psycopg_async_store.set("expired_session_1", b"data", expires_in=1) + await psycopg_async_store.set("expired_session_2", b"data", expires_in=1) + await psycopg_async_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) + + count = await psycopg_async_store.delete_expired() + assert count == 2 + + assert await psycopg_async_store.exists("active_session") + assert await psycopg_async_store.exists("permanent_session") + assert not await psycopg_async_store.exists("expired_session_1") + assert not await psycopg_async_store.exists("expired_session_2") + + +async def test_store_upsert(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test updating existing session (UPSERT).""" + await psycopg_async_store.set("session_upsert", b"original data") + + result = await psycopg_async_store.get("session_upsert") + assert result == b"original data" + + await psycopg_async_store.set("session_upsert", b"updated data") + + result = await psycopg_async_store.get("session_upsert") + assert result == b"updated data" + + +async def test_store_upsert_with_expiration_change(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test updating session expiration.""" + await psycopg_async_store.set("session_exp", b"data", expires_in=60) + + expires_in = await psycopg_async_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await psycopg_async_store.set("session_exp", b"data", expires_in=10) + + expires_in = await psycopg_async_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test renewing session expiration on get.""" + await psycopg_async_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await psycopg_async_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await psycopg_async_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await psycopg_async_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await psycopg_async_store.set("large_session", large_data) + + result = await psycopg_async_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test delete_all removes all sessions.""" + await psycopg_async_store.set("session1", b"data1") + await psycopg_async_store.set("session2", b"data2") + await psycopg_async_store.set("session3", b"data3") + + assert await psycopg_async_store.exists("session1") + assert await psycopg_async_store.exists("session2") + assert await psycopg_async_store.exists("session3") + + await psycopg_async_store.delete_all() + + assert not await psycopg_async_store.exists("session1") + assert not await psycopg_async_store.exists("session2") + assert not await 
psycopg_async_store.exists("session3") + + +async def test_store_exists(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test exists method.""" + assert not await psycopg_async_store.exists("test_session") + + await psycopg_async_store.set("test_session", b"data") + + assert await psycopg_async_store.exists("test_session") + + +async def test_store_context_manager(psycopg_async_store: PsycopgAsyncStore) -> None: + """Test store can be used as async context manager.""" + async with psycopg_async_store: + await psycopg_async_store.set("ctx_session", b"data") + + result = await psycopg_async_store.get("ctx_session") + assert result == b"data" diff --git a/tests/integration/test_adapters/test_psycopg/test_extensions/test_litestar/test_store_sync.py b/tests/integration/test_adapters/test_psycopg/test_extensions/test_litestar/test_store_sync.py new file mode 100644 index 00000000..1d142b34 --- /dev/null +++ b/tests/integration/test_adapters/test_psycopg/test_extensions/test_litestar/test_store_sync.py @@ -0,0 +1,266 @@ +"""Integration tests for Psycopg sync session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest +from pytest_databases.docker.postgres import PostgresService + +from sqlspec.adapters.psycopg.config import PsycopgSyncConfig +from sqlspec.adapters.psycopg.litestar.store import PsycopgSyncStore + +pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.psycopg, pytest.mark.integration] + + +@pytest.fixture +async def psycopg_sync_store(postgres_service: PostgresService) -> AsyncGenerator[PsycopgSyncStore, None]: + """Create Psycopg sync store with test database.""" + config = PsycopgSyncConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "dbname": postgres_service.database, + } + ) + store = PsycopgSyncStore(config, table_name="test_psycopg_sync_sessions") + try: + await store.create_table() + yield store + try: + await store.delete_all() + except Exception: + pass + finally: + if config.pool_instance: + config.close_pool() + + +async def test_store_create_table(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test table creation.""" + assert psycopg_sync_store.table_name == "test_psycopg_sync_sessions" + + +async def test_store_set_and_get(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await psycopg_sync_store.set("session_123", test_data) + + result = await psycopg_sync_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test getting a non-existent session returns None.""" + result = await psycopg_sync_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await psycopg_sync_store.set("session_str", "string data") + + result = await psycopg_sync_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test delete operation.""" + await psycopg_sync_store.set("session_to_delete", b"data") + + assert await psycopg_sync_store.exists("session_to_delete") + + await psycopg_sync_store.delete("session_to_delete") + + assert not await 
psycopg_sync_store.exists("session_to_delete") + assert await psycopg_sync_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await psycopg_sync_store.delete("nonexistent") + + +async def test_store_expiration_with_int(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test session expiration with integer seconds.""" + await psycopg_sync_store.set("expiring_session", b"data", expires_in=1) + + assert await psycopg_sync_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await psycopg_sync_store.get("expiring_session") + assert result is None + assert not await psycopg_sync_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test session expiration with timedelta.""" + await psycopg_sync_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await psycopg_sync_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await psycopg_sync_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test session without expiration persists.""" + await psycopg_sync_store.set("permanent_session", b"data") + + expires_in = await psycopg_sync_store.expires_in("permanent_session") + assert expires_in is None + + assert await psycopg_sync_store.exists("permanent_session") + + +async def test_store_expires_in(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test expires_in returns correct time.""" + await psycopg_sync_store.set("timed_session", b"data", expires_in=10) + + expires_in = await psycopg_sync_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test expires_in returns 0 for expired session.""" + await psycopg_sync_store.set("expired_session", b"data", expires_in=1) + + await asyncio.sleep(1.1) + + expires_in = await psycopg_sync_store.expires_in("expired_session") + assert expires_in == 0 + + +async def test_store_cleanup(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test delete_expired removes only expired sessions.""" + await psycopg_sync_store.set("active_session", b"data", expires_in=60) + await psycopg_sync_store.set("expired_session_1", b"data", expires_in=1) + await psycopg_sync_store.set("expired_session_2", b"data", expires_in=1) + await psycopg_sync_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) + + count = await psycopg_sync_store.delete_expired() + assert count == 2 + + assert await psycopg_sync_store.exists("active_session") + assert await psycopg_sync_store.exists("permanent_session") + assert not await psycopg_sync_store.exists("expired_session_1") + assert not await psycopg_sync_store.exists("expired_session_2") + + +async def test_store_upsert(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test updating existing session (UPSERT).""" + await psycopg_sync_store.set("session_upsert", b"original data") + + result = await psycopg_sync_store.get("session_upsert") + assert result == b"original data" + + await psycopg_sync_store.set("session_upsert", b"updated data") + + result = await psycopg_sync_store.get("session_upsert") + assert result == b"updated data" + + +async def test_store_upsert_with_expiration_change(psycopg_sync_store: 
PsycopgSyncStore) -> None: + """Test updating session expiration.""" + await psycopg_sync_store.set("session_exp", b"data", expires_in=60) + + expires_in = await psycopg_sync_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await psycopg_sync_store.set("session_exp", b"data", expires_in=10) + + expires_in = await psycopg_sync_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test renewing session expiration on get.""" + await psycopg_sync_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await psycopg_sync_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await psycopg_sync_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await psycopg_sync_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await psycopg_sync_store.set("large_session", large_data) + + result = await psycopg_sync_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test delete_all removes all sessions.""" + await psycopg_sync_store.set("session1", b"data1") + await psycopg_sync_store.set("session2", b"data2") + await psycopg_sync_store.set("session3", b"data3") + + assert await psycopg_sync_store.exists("session1") + assert await psycopg_sync_store.exists("session2") + assert await psycopg_sync_store.exists("session3") + + await psycopg_sync_store.delete_all() + + assert not await psycopg_sync_store.exists("session1") + assert not await psycopg_sync_store.exists("session2") + assert not await psycopg_sync_store.exists("session3") + + +async def test_store_exists(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test exists method.""" + assert not await psycopg_sync_store.exists("test_session") + + await psycopg_sync_store.set("test_session", b"data") + + assert await psycopg_sync_store.exists("test_session") + + +async def test_store_context_manager(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test store can be used as async context manager.""" + async with psycopg_sync_store: + await psycopg_sync_store.set("ctx_session", b"data") + + result = await psycopg_sync_store.get("ctx_session") + assert result == b"data" + + +async def test_sync_to_thread_concurrency(psycopg_sync_store: PsycopgSyncStore) -> None: + """Test concurrent access via sync_to_thread wrapper. + + PostgreSQL handles concurrent writes well, so we test concurrent + writes and reads which is a typical session store pattern. 
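+
+    Each awaited store call runs the underlying sync driver operation in a
+    worker thread via the sync_to_thread wrapper, so the asyncio.gather()
+    below drives several threads against the connection pool at once.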
+ """ + for i in range(10): + await psycopg_sync_store.set(f"session_{i}", f"data_{i}".encode()) + + async def read_session(session_id: int) -> "bytes | None": + return await psycopg_sync_store.get(f"session_{session_id}") + + results = await asyncio.gather(*[read_session(i) for i in range(10)]) + + for i, result in enumerate(results): + assert result == f"data_{i}".encode() diff --git a/tests/integration/test_adapters/test_sqlite/test_extensions/__init__.py b/tests/integration/test_adapters/test_sqlite/test_extensions/__init__.py new file mode 100644 index 00000000..e0d6e068 --- /dev/null +++ b/tests/integration/test_adapters/test_sqlite/test_extensions/__init__.py @@ -0,0 +1 @@ +"""SQLite extension integration tests.""" diff --git a/tests/integration/test_adapters/test_sqlite/test_extensions/test_litestar/__init__.py b/tests/integration/test_adapters/test_sqlite/test_extensions/test_litestar/__init__.py new file mode 100644 index 00000000..5211816b --- /dev/null +++ b/tests/integration/test_adapters/test_sqlite/test_extensions/test_litestar/__init__.py @@ -0,0 +1 @@ +"""Integration tests for SQLite Litestar extensions.""" diff --git a/tests/integration/test_adapters/test_sqlite/test_extensions/test_litestar/test_store.py b/tests/integration/test_adapters/test_sqlite/test_extensions/test_litestar/test_store.py new file mode 100644 index 00000000..52de783a --- /dev/null +++ b/tests/integration/test_adapters/test_sqlite/test_extensions/test_litestar/test_store.py @@ -0,0 +1,250 @@ +"""Integration tests for SQLite sync session store.""" + +import asyncio +from collections.abc import AsyncGenerator +from datetime import timedelta + +import pytest + +from sqlspec.adapters.sqlite.config import SqliteConfig +from sqlspec.adapters.sqlite.litestar.store import SQLiteStore + +pytestmark = [pytest.mark.sqlite, pytest.mark.integration] + + +@pytest.fixture +async def sqlite_store() -> AsyncGenerator[SQLiteStore, None]: + """Create SQLite store with shared in-memory database.""" + config = SqliteConfig(pool_config={"database": "file:test_sessions_mem?mode=memory&cache=shared", "uri": True}) + store = SQLiteStore(config, table_name="test_sessions") + await store.create_table() + yield store + await store.delete_all() + + +async def test_store_create_table(sqlite_store: SQLiteStore) -> None: + """Test table creation.""" + assert sqlite_store.table_name == "test_sessions" + + +async def test_store_set_and_get(sqlite_store: SQLiteStore) -> None: + """Test basic set and get operations.""" + test_data = b"test session data" + await sqlite_store.set("session_123", test_data) + + result = await sqlite_store.get("session_123") + assert result == test_data + + +async def test_store_get_nonexistent(sqlite_store: SQLiteStore) -> None: + """Test getting a non-existent session returns None.""" + result = await sqlite_store.get("nonexistent") + assert result is None + + +async def test_store_set_with_string_value(sqlite_store: SQLiteStore) -> None: + """Test setting a string value (should be converted to bytes).""" + await sqlite_store.set("session_str", "string data") + + result = await sqlite_store.get("session_str") + assert result == b"string data" + + +async def test_store_delete(sqlite_store: SQLiteStore) -> None: + """Test delete operation.""" + await sqlite_store.set("session_to_delete", b"data") + + assert await sqlite_store.exists("session_to_delete") + + await sqlite_store.delete("session_to_delete") + + assert not await sqlite_store.exists("session_to_delete") + assert await 
sqlite_store.get("session_to_delete") is None + + +async def test_store_delete_nonexistent(sqlite_store: SQLiteStore) -> None: + """Test deleting a non-existent session is a no-op.""" + await sqlite_store.delete("nonexistent") + + +async def test_store_expiration_with_int(sqlite_store: SQLiteStore) -> None: + """Test session expiration with integer seconds.""" + await sqlite_store.set("expiring_session", b"data", expires_in=1) + + assert await sqlite_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await sqlite_store.get("expiring_session") + assert result is None + assert not await sqlite_store.exists("expiring_session") + + +async def test_store_expiration_with_timedelta(sqlite_store: SQLiteStore) -> None: + """Test session expiration with timedelta.""" + await sqlite_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + + assert await sqlite_store.exists("expiring_session") + + await asyncio.sleep(1.1) + + result = await sqlite_store.get("expiring_session") + assert result is None + + +async def test_store_no_expiration(sqlite_store: SQLiteStore) -> None: + """Test session without expiration persists.""" + await sqlite_store.set("permanent_session", b"data") + + expires_in = await sqlite_store.expires_in("permanent_session") + assert expires_in is None + + assert await sqlite_store.exists("permanent_session") + + +async def test_store_expires_in(sqlite_store: SQLiteStore) -> None: + """Test expires_in returns correct time.""" + await sqlite_store.set("timed_session", b"data", expires_in=10) + + expires_in = await sqlite_store.expires_in("timed_session") + assert expires_in is not None + assert 8 <= expires_in <= 10 + + +async def test_store_expires_in_expired(sqlite_store: SQLiteStore) -> None: + """Test expires_in returns 0 for expired session.""" + await sqlite_store.set("expired_session", b"data", expires_in=1) + + await asyncio.sleep(1.1) + + expires_in = await sqlite_store.expires_in("expired_session") + assert expires_in == 0 + + +async def test_store_cleanup(sqlite_store: SQLiteStore) -> None: + """Test delete_expired removes only expired sessions.""" + await sqlite_store.set("active_session", b"data", expires_in=60) + await sqlite_store.set("expired_session_1", b"data", expires_in=1) + await sqlite_store.set("expired_session_2", b"data", expires_in=1) + await sqlite_store.set("permanent_session", b"data") + + await asyncio.sleep(1.1) + + count = await sqlite_store.delete_expired() + assert count == 2 + + assert await sqlite_store.exists("active_session") + assert await sqlite_store.exists("permanent_session") + assert not await sqlite_store.exists("expired_session_1") + assert not await sqlite_store.exists("expired_session_2") + + +async def test_store_upsert(sqlite_store: SQLiteStore) -> None: + """Test updating existing session (UPSERT).""" + await sqlite_store.set("session_upsert", b"original data") + + result = await sqlite_store.get("session_upsert") + assert result == b"original data" + + await sqlite_store.set("session_upsert", b"updated data") + + result = await sqlite_store.get("session_upsert") + assert result == b"updated data" + + +async def test_store_upsert_with_expiration_change(sqlite_store: SQLiteStore) -> None: + """Test updating session expiration.""" + await sqlite_store.set("session_exp", b"data", expires_in=60) + + expires_in = await sqlite_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in > 50 + + await sqlite_store.set("session_exp", b"data", expires_in=10) + + expires_in = 
await sqlite_store.expires_in("session_exp") + assert expires_in is not None + assert expires_in <= 10 + + +async def test_store_renew_for(sqlite_store: SQLiteStore) -> None: + """Test renewing session expiration on get.""" + await sqlite_store.set("session_renew", b"data", expires_in=5) + + await asyncio.sleep(3) + + expires_before = await sqlite_store.expires_in("session_renew") + assert expires_before is not None + assert expires_before <= 2 + + result = await sqlite_store.get("session_renew", renew_for=10) + assert result == b"data" + + expires_after = await sqlite_store.expires_in("session_renew") + assert expires_after is not None + assert expires_after > 8 + + +async def test_store_large_data(sqlite_store: SQLiteStore) -> None: + """Test storing large session data (>1MB).""" + large_data = b"x" * (1024 * 1024 + 100) + + await sqlite_store.set("large_session", large_data) + + result = await sqlite_store.get("large_session") + assert result is not None + assert result == large_data + assert len(result) > 1024 * 1024 + + +async def test_store_delete_all(sqlite_store: SQLiteStore) -> None: + """Test delete_all removes all sessions.""" + await sqlite_store.set("session1", b"data1") + await sqlite_store.set("session2", b"data2") + await sqlite_store.set("session3", b"data3") + + assert await sqlite_store.exists("session1") + assert await sqlite_store.exists("session2") + assert await sqlite_store.exists("session3") + + await sqlite_store.delete_all() + + assert not await sqlite_store.exists("session1") + assert not await sqlite_store.exists("session2") + assert not await sqlite_store.exists("session3") + + +async def test_store_exists(sqlite_store: SQLiteStore) -> None: + """Test exists method.""" + assert not await sqlite_store.exists("test_session") + + await sqlite_store.set("test_session", b"data") + + assert await sqlite_store.exists("test_session") + + +async def test_store_context_manager(sqlite_store: SQLiteStore) -> None: + """Test store can be used as async context manager.""" + async with sqlite_store: + await sqlite_store.set("ctx_session", b"data") + + result = await sqlite_store.get("ctx_session") + assert result == b"data" + + +async def test_sync_to_thread_concurrency(sqlite_store: SQLiteStore) -> None: + """Test concurrent access via sync_to_thread wrapper. + + SQLite has write serialization, so we test sequential writes + followed by concurrent reads, which is the typical session store pattern. 
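+ + Note (behavioral sketch, assuming sync_to_thread dispatches each blocking + call to a worker thread): the shared-cache in-memory database serializes + writers, while the gathered reads below interleave off the event loop.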
+ """ + for i in range(10): + await sqlite_store.set(f"session_{i}", f"data_{i}".encode()) + + async def read_session(session_id: int) -> "bytes | None": + return await sqlite_store.get(f"session_{session_id}") + + results = await asyncio.gather(*[read_session(i) for i in range(10)]) + + for i, result in enumerate(results): + assert result == f"data_{i}".encode() diff --git a/tests/unit/test_adapters/test_async_adapters.py b/tests/unit/test_adapters/test_async_adapters.py index ed5543e4..aaa21060 100644 --- a/tests/unit/test_adapters/test_async_adapters.py +++ b/tests/unit/test_adapters/test_async_adapters.py @@ -18,7 +18,6 @@ __all__ = () -@pytest.mark.asyncio async def test_async_driver_initialization(mock_async_connection: MockAsyncConnection) -> None: """Test basic async driver initialization.""" driver = MockAsyncDriver(mock_async_connection) @@ -29,7 +28,6 @@ async def test_async_driver_initialization(mock_async_connection: MockAsyncConne assert driver.statement_config.parameter_config.default_parameter_style == ParameterStyle.QMARK -@pytest.mark.asyncio async def test_async_driver_with_custom_config(mock_async_connection: MockAsyncConnection) -> None: """Test async driver initialization with custom statement config.""" custom_config = StatementConfig( @@ -44,7 +42,6 @@ async def test_async_driver_with_custom_config(mock_async_connection: MockAsyncC assert driver.statement_config.parameter_config.default_parameter_style == ParameterStyle.NUMERIC -@pytest.mark.asyncio async def test_async_driver_with_cursor(mock_async_driver: MockAsyncDriver) -> None: """Test async cursor context manager functionality.""" async with mock_async_driver.with_cursor(mock_async_driver.connection) as cursor: @@ -54,7 +51,6 @@ async def test_async_driver_with_cursor(mock_async_driver: MockAsyncDriver) -> N assert cursor.connection is mock_async_driver.connection -@pytest.mark.asyncio async def test_async_driver_database_exception_handling(mock_async_driver: MockAsyncDriver) -> None: """Test async database exception handling context manager.""" async with mock_async_driver.handle_database_exceptions(): @@ -65,7 +61,6 @@ async def test_async_driver_database_exception_handling(mock_async_driver: MockA raise ValueError("Test async error") -@pytest.mark.asyncio async def test_async_driver_execute_statement_select(mock_async_driver: MockAsyncDriver) -> None: """Test async _execute_statement method with SELECT query.""" statement = SQL("SELECT id, name FROM users", statement_config=mock_async_driver.statement_config) @@ -81,7 +76,6 @@ async def test_async_driver_execute_statement_select(mock_async_driver: MockAsyn assert result.data_row_count == 2 -@pytest.mark.asyncio async def test_async_driver_execute_statement_insert(mock_async_driver: MockAsyncDriver) -> None: """Test async _execute_statement method with INSERT query.""" statement = SQL("INSERT INTO users (name) VALUES (?)", "test", statement_config=mock_async_driver.statement_config) @@ -96,7 +90,6 @@ async def test_async_driver_execute_statement_insert(mock_async_driver: MockAsyn assert result.selected_data is None -@pytest.mark.asyncio async def test_async_driver_execute_many(mock_async_driver: MockAsyncDriver) -> None: """Test async _execute_many method.""" statement = SQL( @@ -115,7 +108,6 @@ async def test_async_driver_execute_many(mock_async_driver: MockAsyncDriver) -> assert mock_async_driver.connection.execute_many_count == 1 -@pytest.mark.asyncio async def test_async_driver_execute_many_no_parameters(mock_async_driver: MockAsyncDriver) -> None: 
"""Test async _execute_many method fails without parameters.""" statement = SQL( @@ -126,7 +118,6 @@ async def test_async_driver_execute_many_no_parameters(mock_async_driver: MockAs await mock_async_driver._execute_many(cursor, statement) -@pytest.mark.asyncio async def test_async_driver_execute_script(mock_async_driver: MockAsyncDriver) -> None: """Test async _execute_script method.""" script = """ @@ -145,7 +136,6 @@ async def test_async_driver_execute_script(mock_async_driver: MockAsyncDriver) - assert result.successful_statements == 3 -@pytest.mark.asyncio async def test_async_driver_dispatch_statement_execution_select(mock_async_driver: MockAsyncDriver) -> None: """Test async dispatch_statement_execution with SELECT statement.""" statement = SQL("SELECT * FROM users", statement_config=mock_async_driver.statement_config) @@ -159,7 +149,6 @@ async def test_async_driver_dispatch_statement_execution_select(mock_async_drive assert result.get_data()[0]["name"] == "test" -@pytest.mark.asyncio async def test_async_driver_dispatch_statement_execution_insert(mock_async_driver: MockAsyncDriver) -> None: """Test async dispatch_statement_execution with INSERT statement.""" statement = SQL("INSERT INTO users (name) VALUES (?)", "test", statement_config=mock_async_driver.statement_config) @@ -172,7 +161,6 @@ async def test_async_driver_dispatch_statement_execution_insert(mock_async_drive assert len(result.get_data()) == 0 -@pytest.mark.asyncio async def test_async_driver_dispatch_statement_execution_script(mock_async_driver: MockAsyncDriver) -> None: """Test async dispatch_statement_execution with script.""" script = "INSERT INTO users (name) VALUES ('alice'); INSERT INTO users (name) VALUES ('bob');" @@ -186,7 +174,6 @@ async def test_async_driver_dispatch_statement_execution_script(mock_async_drive assert result.successful_statements == 2 -@pytest.mark.asyncio async def test_async_driver_dispatch_statement_execution_many(mock_async_driver: MockAsyncDriver) -> None: """Test async dispatch_statement_execution with execute_many.""" statement = SQL( @@ -203,7 +190,6 @@ async def test_async_driver_dispatch_statement_execution_many(mock_async_driver: assert result.rows_affected == 2 -@pytest.mark.asyncio async def test_async_driver_transaction_management(mock_async_driver: MockAsyncDriver) -> None: """Test async transaction management methods.""" connection = mock_async_driver.connection @@ -220,7 +206,6 @@ async def test_async_driver_transaction_management(mock_async_driver: MockAsyncD assert connection.in_transaction is False -@pytest.mark.asyncio async def test_async_driver_execute_method(mock_async_driver: MockAsyncDriver) -> None: """Test high-level async execute method.""" result = await mock_async_driver.execute("SELECT * FROM users WHERE id = ?", 1) @@ -230,7 +215,6 @@ async def test_async_driver_execute_method(mock_async_driver: MockAsyncDriver) - assert len(result.get_data()) == 2 -@pytest.mark.asyncio async def test_async_driver_execute_many_method(mock_async_driver: MockAsyncDriver) -> None: """Test high-level async execute_many method.""" parameters = [["alice"], ["bob"], ["charlie"]] @@ -241,7 +225,6 @@ async def test_async_driver_execute_many_method(mock_async_driver: MockAsyncDriv assert result.rows_affected == 3 -@pytest.mark.asyncio async def test_async_driver_execute_script_method(mock_async_driver: MockAsyncDriver) -> None: """Test high-level async execute_script method.""" script = "INSERT INTO users (name) VALUES ('alice'); UPDATE users SET active = 1;" @@ -253,14 +236,12 @@ async 
def test_async_driver_execute_script_method(mock_async_driver: MockAsyncDr assert result.successful_statements == 2 -@pytest.mark.asyncio async def test_async_driver_select_one(mock_async_driver: MockAsyncDriver) -> None: """Test async select_one method - expects error when multiple rows returned.""" with pytest.raises(ValueError, match="Expected exactly one row, found 2"): await mock_async_driver.select_one("SELECT * FROM users WHERE id = ?", 1) -@pytest.mark.asyncio async def test_async_driver_select_one_no_results(mock_async_driver: MockAsyncDriver) -> None: """Test async select_one method with no results.""" @@ -273,7 +254,6 @@ async def test_async_driver_select_one_no_results(mock_async_driver: MockAsyncDr await mock_async_driver.select_one("SELECT * FROM users WHERE id = ?", 999) -@pytest.mark.asyncio async def test_async_driver_select_one_multiple_results(mock_async_driver: MockAsyncDriver) -> None: """Test async select_one method with multiple results.""" @@ -286,14 +266,12 @@ async def test_async_driver_select_one_multiple_results(mock_async_driver: MockA await mock_async_driver.select_one("SELECT * FROM users") -@pytest.mark.asyncio async def test_async_driver_select_one_or_none(mock_async_driver: MockAsyncDriver) -> None: """Test async select_one_or_none method - expects error when multiple rows returned.""" with pytest.raises(ValueError, match="Expected at most one row, found 2"): await mock_async_driver.select_one_or_none("SELECT * FROM users WHERE id = ?", 1) -@pytest.mark.asyncio async def test_async_driver_select_one_or_none_no_results(mock_async_driver: MockAsyncDriver) -> None: """Test async select_one_or_none method with no results.""" with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: @@ -305,7 +283,6 @@ async def test_async_driver_select_one_or_none_no_results(mock_async_driver: Moc assert result is None -@pytest.mark.asyncio async def test_async_driver_select_one_or_none_multiple_results(mock_async_driver: MockAsyncDriver) -> None: """Test async select_one_or_none method with multiple results.""" with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: @@ -317,7 +294,6 @@ async def test_async_driver_select_one_or_none_multiple_results(mock_async_drive await mock_async_driver.select_one_or_none("SELECT * FROM users") -@pytest.mark.asyncio async def test_async_driver_select(mock_async_driver: MockAsyncDriver) -> None: """Test async select method.""" result: list[dict[str, Any]] = await mock_async_driver.select("SELECT * FROM users") @@ -328,7 +304,6 @@ async def test_async_driver_select(mock_async_driver: MockAsyncDriver) -> None: assert result[1]["id"] == 2 -@pytest.mark.asyncio async def test_async_driver_select_value(mock_async_driver: MockAsyncDriver) -> None: """Test async select_value method.""" @@ -341,7 +316,6 @@ async def test_async_driver_select_value(mock_async_driver: MockAsyncDriver) -> assert result == 42 -@pytest.mark.asyncio async def test_async_driver_select_value_no_results(mock_async_driver: MockAsyncDriver) -> None: """Test async select_value method with no results.""" with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: @@ -353,14 +327,12 @@ async def test_async_driver_select_value_no_results(mock_async_driver: MockAsync await mock_async_driver.select_value("SELECT COUNT(*) FROM users WHERE id = 999") -@pytest.mark.asyncio async def test_async_driver_select_value_or_none(mock_async_driver: MockAsyncDriver) -> None: """Test async 
select_value_or_none method - expects error when multiple rows returned.""" with pytest.raises(ValueError, match="Expected at most one row, found 2"): await mock_async_driver.select_value_or_none("SELECT * FROM users WHERE id = ?", 1) -@pytest.mark.asyncio async def test_async_driver_select_value_or_none_no_results(mock_async_driver: MockAsyncDriver) -> None: """Test async select_value_or_none method with no results.""" with patch.object(mock_async_driver, "execute", new_callable=AsyncMock) as mock_execute: @@ -372,7 +344,6 @@ async def test_async_driver_select_value_or_none_no_results(mock_async_driver: M assert result is None -@pytest.mark.asyncio @pytest.mark.parametrize( "parameter_style,expected_style", [ @@ -412,7 +383,6 @@ async def test_async_driver_parameter_styles( assert isinstance(result, SQLResult) -@pytest.mark.asyncio @pytest.mark.parametrize("dialect", ["sqlite", "postgres", "mysql"]) async def test_async_driver_different_dialects(mock_async_connection: MockAsyncConnection, dialect: str) -> None: """Test async driver works with different SQL dialects.""" @@ -430,7 +400,6 @@ async def test_async_driver_different_dialects(mock_async_connection: MockAsyncC assert isinstance(result, SQLResult) -@pytest.mark.asyncio async def test_async_driver_create_execution_result(mock_async_driver: MockAsyncDriver) -> None: """Test async create_execution_result method.""" cursor = mock_async_driver.with_cursor(mock_async_driver.connection) @@ -456,7 +425,6 @@ async def test_async_driver_create_execution_result(mock_async_driver: MockAsync assert result.successful_statements == 3 -@pytest.mark.asyncio async def test_async_driver_build_statement_result(mock_async_driver: MockAsyncDriver) -> None: """Test async build_statement_result method.""" statement = SQL("SELECT * FROM users", statement_config=mock_async_driver.statement_config) @@ -485,7 +453,6 @@ async def test_async_driver_build_statement_result(mock_async_driver: MockAsyncD assert script_sql_result.successful_statements == 1 -@pytest.mark.asyncio async def test_async_driver_special_handling_integration(mock_async_driver: MockAsyncDriver) -> None: """Test that async _try_special_handling is called during dispatch.""" statement = SQL("SELECT * FROM users", statement_config=mock_async_driver.statement_config) @@ -499,7 +466,6 @@ async def test_async_driver_special_handling_integration(mock_async_driver: Mock mock_special.assert_called_once() -@pytest.mark.asyncio async def test_async_driver_error_handling_in_dispatch(mock_async_driver: MockAsyncDriver) -> None: """Test error handling during async statement dispatch.""" statement = SQL("SELECT * FROM users", statement_config=mock_async_driver.statement_config) @@ -511,7 +477,6 @@ async def test_async_driver_error_handling_in_dispatch(mock_async_driver: MockAs await mock_async_driver.dispatch_statement_execution(statement, mock_async_driver.connection) -@pytest.mark.asyncio async def test_async_driver_statement_processing_integration(mock_async_driver: MockAsyncDriver) -> None: """Test async driver statement processing integration.""" statement = SQL("SELECT * FROM users WHERE active = ?", True, statement_config=mock_async_driver.statement_config) @@ -523,7 +488,6 @@ async def test_async_driver_statement_processing_integration(mock_async_driver: assert mock_compile.called or statement.sql == "SELECT * FROM test" -@pytest.mark.asyncio async def test_async_driver_context_manager_integration(mock_async_driver: MockAsyncDriver) -> None: """Test async context manager integration during 
execution.""" statement = SQL("SELECT * FROM users", statement_config=mock_async_driver.statement_config) @@ -543,7 +507,6 @@ async def test_async_driver_context_manager_integration(mock_async_driver: MockA mock_handle_exceptions.assert_called_once() -@pytest.mark.asyncio async def test_async_driver_resource_cleanup(mock_async_driver: MockAsyncDriver) -> None: """Test async resource cleanup during execution.""" connection = mock_async_driver.connection @@ -555,7 +518,6 @@ async def test_async_driver_resource_cleanup(mock_async_driver: MockAsyncDriver) assert cursor.closed is True -@pytest.mark.asyncio async def test_async_driver_concurrent_execution(mock_async_connection: MockAsyncConnection) -> None: """Test concurrent execution capability of async driver.""" import asyncio @@ -574,7 +536,6 @@ async def execute_query(query_id: int) -> SQLResult: assert result.operation_type == "SELECT" -@pytest.mark.asyncio async def test_async_driver_with_transaction_context(mock_async_driver: MockAsyncDriver) -> None: """Test async driver transaction context usage.""" connection = mock_async_driver.connection diff --git a/tests/unit/test_adapters/test_extension_config.py b/tests/unit/test_adapters/test_extension_config.py new file mode 100644 index 00000000..f032c7f7 --- /dev/null +++ b/tests/unit/test_adapters/test_extension_config.py @@ -0,0 +1,194 @@ +"""Test extension_config parameter support across all adapters.""" + +from typing import Any + +import pytest + +from sqlspec.adapters.adbc import AdbcConfig +from sqlspec.adapters.aiosqlite import AiosqliteConfig +from sqlspec.adapters.asyncmy import AsyncmyConfig +from sqlspec.adapters.asyncpg import AsyncpgConfig +from sqlspec.adapters.bigquery import BigQueryConfig +from sqlspec.adapters.duckdb import DuckDBConfig +from sqlspec.adapters.oracledb import OracleAsyncConfig, OracleSyncConfig +from sqlspec.adapters.psqlpy import PsqlpyConfig +from sqlspec.adapters.psycopg import PsycopgAsyncConfig, PsycopgSyncConfig +from sqlspec.adapters.sqlite import SqliteConfig + + +def test_sqlite_extension_config() -> None: + """Test SqliteConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"session_key": "custom_session", "commit_mode": "manual"}} + + config = SqliteConfig(pool_config={"database": ":memory:"}, extension_config=extension_config) + + assert config.extension_config == extension_config + assert config.extension_config["litestar"]["session_key"] == "custom_session" + + +def test_aiosqlite_extension_config() -> None: + """Test AiosqliteConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"pool_key": "db_pool", "enable_correlation_middleware": False}} + + config = AiosqliteConfig(pool_config={"database": ":memory:"}, extension_config=extension_config) + + assert config.extension_config == extension_config + assert config.extension_config["litestar"]["pool_key"] == "db_pool" + + +def test_duckdb_extension_config() -> None: + """Test DuckDBConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"connection_key": "duckdb_conn"}} + + config = DuckDBConfig(pool_config={"database": ":memory:"}, extension_config=extension_config) + + assert config.extension_config == extension_config + + +def test_asyncpg_extension_config() -> None: + """Test AsyncpgConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"commit_mode": "autocommit"}} + + config = AsyncpgConfig(pool_config={"host": "localhost", "database": "test"}, extension_config=extension_config) + + 
assert config.extension_config == extension_config + + +def test_psycopg_sync_extension_config() -> None: + """Test PsycopgSyncConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"session_key": "psycopg_session"}} + + config = PsycopgSyncConfig(pool_config={"host": "localhost", "dbname": "test"}, extension_config=extension_config) + + assert config.extension_config == extension_config + + +def test_psycopg_async_extension_config() -> None: + """Test PsycopgAsyncConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"extra_commit_statuses": {201, 202}}} + + config = PsycopgAsyncConfig(pool_config={"host": "localhost", "dbname": "test"}, extension_config=extension_config) + + assert config.extension_config == extension_config + + +def test_asyncmy_extension_config() -> None: + """Test AsyncmyConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"commit_mode": "autocommit_include_redirect"}} + + config = AsyncmyConfig(pool_config={"host": "localhost", "database": "test"}, extension_config=extension_config) + + assert config.extension_config == extension_config + + +def test_psqlpy_extension_config() -> None: + """Test PsqlpyConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"extra_rollback_statuses": {400, 500}}} + + config = PsqlpyConfig(pool_config={"host": "localhost", "db_name": "test"}, extension_config=extension_config) + + assert config.extension_config == extension_config + + +def test_oracle_sync_extension_config() -> None: + """Test OracleSyncConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"enable_correlation_middleware": True}} + + config = OracleSyncConfig(pool_config={"user": "test", "password": "test"}, extension_config=extension_config) + + assert config.extension_config == extension_config + + +def test_oracle_async_extension_config() -> None: + """Test OracleAsyncConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"connection_key": "oracle_async"}} + + config = OracleAsyncConfig(pool_config={"user": "test", "password": "test"}, extension_config=extension_config) + + assert config.extension_config == extension_config + + +def test_adbc_extension_config() -> None: + """Test AdbcConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"session_key": "adbc_session"}} + + config = AdbcConfig( + connection_config={"driver_name": "sqlite", "uri": "sqlite://:memory:"}, extension_config=extension_config + ) + + assert config.extension_config == extension_config + + +def test_bigquery_extension_config() -> None: + """Test BigQueryConfig accepts and stores extension_config.""" + extension_config = {"litestar": {"pool_key": "bigquery_pool"}} + + config = BigQueryConfig(connection_config={"project": "test-project"}, extension_config=extension_config) + + assert config.extension_config == extension_config + + +def test_extension_config_defaults_to_empty_dict() -> None: + """Test that extension_config defaults to empty dict when not provided.""" + configs = [ + SqliteConfig(pool_config={"database": ":memory:"}), + DuckDBConfig(pool_config={"database": ":memory:"}), + AiosqliteConfig(pool_config={"database": ":memory:"}), + AsyncpgConfig(pool_config={"host": "localhost"}), + PsycopgSyncConfig(pool_config={"host": "localhost"}), + PsycopgAsyncConfig(pool_config={"host": "localhost"}), + AsyncmyConfig(pool_config={"host": "localhost"}), + PsqlpyConfig(pool_config={"host": "localhost"}), + 
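# The remaining entries use different kwargs: Oracle takes credentials, + # while ADBC and BigQuery are configured via connection_config. + 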
OracleSyncConfig(pool_config={"user": "test", "password": "test"}), + OracleAsyncConfig(pool_config={"user": "test", "password": "test"}), + AdbcConfig(connection_config={"driver_name": "sqlite", "uri": "sqlite://:memory:"}), + BigQueryConfig(connection_config={"project": "test"}), + ] + + for config in configs: + assert hasattr(config, "extension_config") + assert config.extension_config == {} + + +def test_extension_config_with_multiple_extensions() -> None: + """Test extension_config can hold multiple extension configurations.""" + extension_config: dict[str, dict[str, Any]] = { + "litestar": {"session_key": "db_session", "commit_mode": "manual"}, + "custom_extension": {"setting1": "value1", "setting2": 42}, + "another_ext": {"enabled": True}, + } + + config = SqliteConfig(pool_config={"database": ":memory:"}, extension_config=extension_config) + + assert config.extension_config == extension_config + assert len(config.extension_config) == 3 + assert "litestar" in config.extension_config + assert "custom_extension" in config.extension_config + assert "another_ext" in config.extension_config + + +@pytest.mark.parametrize( + "config_class,init_kwargs", + [ + (SqliteConfig, {"pool_config": {"database": ":memory:"}}), + (AiosqliteConfig, {"pool_config": {"database": ":memory:"}}), + (DuckDBConfig, {"pool_config": {"database": ":memory:"}}), + (AsyncpgConfig, {"pool_config": {"host": "localhost"}}), + (PsycopgSyncConfig, {"pool_config": {"host": "localhost"}}), + (PsycopgAsyncConfig, {"pool_config": {"host": "localhost"}}), + (AsyncmyConfig, {"pool_config": {"host": "localhost"}}), + (PsqlpyConfig, {"pool_config": {"host": "localhost"}}), + (OracleSyncConfig, {"pool_config": {"user": "test", "password": "test"}}), + (OracleAsyncConfig, {"pool_config": {"user": "test", "password": "test"}}), + (AdbcConfig, {"connection_config": {"driver_name": "sqlite", "uri": "sqlite://:memory:"}}), + (BigQueryConfig, {"connection_config": {"project": "test"}}), + ], +) +def test_all_adapters_accept_extension_config(config_class: type, init_kwargs: dict) -> None: + """Parameterized test ensuring all adapters accept extension_config.""" + extension_config = {"test_extension": {"test_key": "test_value"}} + + config = config_class(**init_kwargs, extension_config=extension_config) + + assert hasattr(config, "extension_config") + assert config.extension_config == extension_config diff --git a/tests/unit/test_extensions/test_litestar/test_handlers.py b/tests/unit/test_extensions/test_litestar/test_handlers.py index 518b7961..8af30136 100644 --- a/tests/unit/test_extensions/test_litestar/test_handlers.py +++ b/tests/unit/test_extensions/test_litestar/test_handlers.py @@ -22,7 +22,6 @@ from litestar.types import Message, Scope -@pytest.mark.asyncio async def test_async_manual_handler_closes_connection() -> None: """Test async manual handler closes connection on terminus event.""" connection_key = "test_connection" @@ -42,7 +41,6 @@ async def test_async_manual_handler_closes_connection() -> None: assert get_sqlspec_scope_state(scope, connection_key) is None -@pytest.mark.asyncio async def test_async_manual_handler_ignores_non_terminus_events() -> None: """Test async manual handler ignores non-terminus events.""" connection_key = "test_connection" @@ -62,7 +60,6 @@ async def test_async_manual_handler_ignores_non_terminus_events() -> None: assert get_sqlspec_scope_state(scope, connection_key) is mock_connection -@pytest.mark.asyncio async def test_async_autocommit_handler_commits_on_success() -> None: """Test async 
autocommit handler commits on 2xx status.""" connection_key = "test_connection" @@ -85,7 +82,6 @@ async def test_async_autocommit_handler_commits_on_success() -> None: mock_connection.close.assert_awaited_once() -@pytest.mark.asyncio async def test_async_autocommit_handler_rolls_back_on_error() -> None: """Test async autocommit handler rolls back on 4xx/5xx status.""" connection_key = "test_connection" @@ -108,7 +104,6 @@ async def test_async_autocommit_handler_rolls_back_on_error() -> None: mock_connection.close.assert_awaited_once() -@pytest.mark.asyncio async def test_async_autocommit_handler_with_redirect_commit() -> None: """Test async autocommit handler commits on 3xx when enabled.""" connection_key = "test_connection" @@ -129,7 +124,6 @@ async def test_async_autocommit_handler_with_redirect_commit() -> None: mock_connection.rollback.assert_not_awaited() -@pytest.mark.asyncio async def test_async_autocommit_handler_extra_commit_statuses() -> None: """Test async autocommit handler uses extra commit statuses.""" connection_key = "test_connection" @@ -150,7 +144,6 @@ async def test_async_autocommit_handler_extra_commit_statuses() -> None: mock_connection.rollback.assert_not_awaited() -@pytest.mark.asyncio async def test_async_autocommit_handler_raises_on_conflicting_statuses() -> None: """Test async autocommit handler raises error when status sets overlap.""" with pytest.raises(ImproperConfigurationError) as exc_info: @@ -159,7 +152,6 @@ async def test_async_autocommit_handler_raises_on_conflicting_statuses() -> None assert "must not share" in str(exc_info.value) -@pytest.mark.asyncio async def test_async_lifespan_handler_creates_and_closes_pool() -> None: """Test async lifespan handler manages pool lifecycle.""" config = AiosqliteConfig(pool_config={"database": ":memory:"}) @@ -179,7 +171,6 @@ async def test_async_lifespan_handler_creates_and_closes_pool() -> None: assert pool_key not in mock_app.state -@pytest.mark.asyncio async def test_async_pool_provider_returns_pool() -> None: """Test async pool provider returns pool from state.""" config = AiosqliteConfig(pool_config={"database": ":memory:"}) @@ -198,7 +189,6 @@ async def test_async_pool_provider_returns_pool() -> None: state.get.assert_called_once_with(pool_key) -@pytest.mark.asyncio async def test_async_pool_provider_raises_when_pool_missing() -> None: """Test async pool provider raises error when pool not in state.""" config = AiosqliteConfig(pool_config={"database": ":memory:"}) @@ -217,7 +207,6 @@ async def test_async_pool_provider_raises_when_pool_missing() -> None: assert "not found in application state" in str(exc_info.value) -@pytest.mark.asyncio async def test_async_connection_provider_creates_connection() -> None: """Test async connection provider creates connection from pool.""" config = AiosqliteConfig(pool_config={"database": ":memory:"}) @@ -237,7 +226,6 @@ async def test_async_connection_provider_creates_connection() -> None: assert get_sqlspec_scope_state(scope, connection_key) is connection -@pytest.mark.asyncio async def test_async_connection_provider_raises_when_pool_missing() -> None: """Test async connection provider raises error when pool missing.""" config = AiosqliteConfig(pool_config={"database": ":memory:"}) @@ -257,7 +245,6 @@ async def test_async_connection_provider_raises_when_pool_missing() -> None: assert pool_key in str(exc_info.value) -@pytest.mark.asyncio async def test_async_session_provider_creates_session() -> None: """Test async session provider creates driver session.""" config = 
AiosqliteConfig(pool_config={"database": ":memory:"}) diff --git a/tests/unit/test_utils/test_correlation.py b/tests/unit/test_utils/test_correlation.py index 20c2bc7c..6fa198b4 100644 --- a/tests/unit/test_utils/test_correlation.py +++ b/tests/unit/test_utils/test_correlation.py @@ -314,7 +314,6 @@ def operation(name: str) -> None: assert results[2]["correlation_id"] == "request-123" -@pytest.mark.asyncio async def test_async_context_preservation() -> None: """Test that correlation context is preserved across async operations.""" diff --git a/tests/unit/test_utils/test_fixtures.py b/tests/unit/test_utils/test_fixtures.py index 0c4eeead..2f9cd863 100644 --- a/tests/unit/test_utils/test_fixtures.py +++ b/tests/unit/test_utils/test_fixtures.py @@ -297,7 +297,6 @@ def test_open_fixture_invalid_json() -> None: open_fixture(fixtures_path, "invalid") -@pytest.mark.asyncio async def test_open_fixture_async_valid_file() -> None: """Test open_fixture_async with valid JSON fixture file.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -312,7 +311,6 @@ async def test_open_fixture_async_valid_file() -> None: assert result == test_data -@pytest.mark.asyncio async def test_open_fixture_async_gzipped() -> None: """Test open_fixture_async with gzipped file.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -327,7 +325,6 @@ async def test_open_fixture_async_gzipped() -> None: assert result == test_data -@pytest.mark.asyncio async def test_open_fixture_async_zipped() -> None: """Test open_fixture_async with zipped file.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -342,7 +339,6 @@ async def test_open_fixture_async_zipped() -> None: assert result == test_data -@pytest.mark.asyncio async def test_open_fixture_async_missing_file() -> None: """Test open_fixture_async with missing fixture file.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -417,7 +413,6 @@ def test_write_fixture_with_custom_backend(mock_registry: Mock) -> None: mock_storage.write_text.assert_called_once() -@pytest.mark.asyncio async def test_write_fixture_async_dict() -> None: """Test async writing a dictionary fixture.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -430,7 +425,6 @@ async def test_write_fixture_async_dict() -> None: assert loaded_data == test_data -@pytest.mark.asyncio async def test_write_fixture_async_compressed() -> None: """Test async writing a compressed fixture.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -447,7 +441,6 @@ async def test_write_fixture_async_compressed() -> None: assert loaded_data == test_data -@pytest.mark.asyncio async def test_write_fixture_async_storage_error() -> None: """Test async error handling for invalid storage backend.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -457,7 +450,6 @@ async def test_write_fixture_async_storage_error() -> None: await write_fixture_async(temp_dir, "test", test_data, storage_backend="invalid://backend") -@pytest.mark.asyncio @patch("sqlspec.utils.fixtures.storage_registry") async def test_write_fixture_async_custom_backend(mock_registry: Mock) -> None: """Test async write_fixture with custom storage backend.""" @@ -492,7 +484,6 @@ def test_write_read_roundtrip() -> None: assert loaded_data == original_data -@pytest.mark.asyncio async def test_async_write_read_roundtrip() -> None: """Test complete async write and read roundtrip.""" with tempfile.TemporaryDirectory() as temp_dir: diff --git a/tests/unit/test_utils/test_sync_tools.py b/tests/unit/test_utils/test_sync_tools.py index 6a5c5070..41fd73a9 100644 --- 
a/tests/unit/test_utils/test_sync_tools.py +++ b/tests/unit/test_utils/test_sync_tools.py @@ -39,7 +39,6 @@ def test_capacity_limiter_property_setter() -> None: assert limiter.total_tokens == 10 -@pytest.mark.asyncio async def test_capacity_limiter_async_context() -> None: """Test CapacityLimiter as async context manager.""" limiter = CapacityLimiter(1) @@ -50,7 +49,6 @@ async def test_capacity_limiter_async_context() -> None: assert limiter._semaphore._value == 1 -@pytest.mark.asyncio async def test_capacity_limiter_acquire_release() -> None: """Test CapacityLimiter manual acquire/release.""" limiter = CapacityLimiter(1) @@ -62,7 +60,6 @@ async def test_capacity_limiter_acquire_release() -> None: assert limiter._semaphore._value == 1 -@pytest.mark.asyncio async def test_capacity_limiter_concurrent_access_edge_cases() -> None: """Test CapacityLimiter with edge case concurrent scenarios.""" limiter = CapacityLimiter(1) @@ -176,7 +173,6 @@ async def simple_async_func(x: int) -> int: sync_func_strict(21) -@pytest.mark.asyncio async def test_async_basic() -> None: """Test async_ decorator basic functionality.""" @@ -188,7 +184,6 @@ def sync_function(x: int) -> int: assert result == 12 -@pytest.mark.asyncio async def test_async_with_limiter() -> None: """Test async_ decorator with custom limiter.""" limiter = CapacityLimiter(1) @@ -201,7 +196,6 @@ def sync_function(x: int) -> int: assert result == 10 -@pytest.mark.asyncio async def test_ensure_async_with_async_function() -> None: """Test ensure_async_ with already async function.""" @@ -213,7 +207,6 @@ async def already_async(x: int) -> int: assert result == 12 -@pytest.mark.asyncio async def test_ensure_async_with_sync_function() -> None: """Test ensure_async_ with sync function.""" @@ -225,7 +218,6 @@ def sync_function(x: int) -> int: assert result == 21 -@pytest.mark.asyncio async def test_ensure_async_exception_propagation() -> None: """Test ensure_async_ properly propagates exceptions.""" @@ -237,7 +229,6 @@ def sync_func_that_raises() -> None: await sync_func_that_raises() -@pytest.mark.asyncio async def test_with_ensure_async_context_manager() -> None: """Test with_ensure_async_ with sync context manager.""" @@ -263,7 +254,6 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: assert result.exited is True -@pytest.mark.asyncio async def test_with_ensure_async_async_context_manager() -> None: """Test with_ensure_async_ with already async context manager.""" @@ -289,7 +279,6 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: assert result.exited is True -@pytest.mark.asyncio async def test_get_next_basic() -> None: """Test get_next with async iterator.""" @@ -317,7 +306,6 @@ async def __anext__(self) -> int: assert result2 == 2 -@pytest.mark.asyncio async def test_get_next_with_default() -> None: """Test get_next with default value when iterator is exhausted.""" @@ -334,7 +322,6 @@ async def __anext__(self) -> int: assert result == "default_value" -@pytest.mark.asyncio async def test_get_next_no_default_behavior() -> None: """Test get_next behavior when iterator is exhausted without default.""" @@ -372,7 +359,6 @@ async def async_function_with_error() -> None: async_function_with_error() -@pytest.mark.asyncio async def test_async_tools_integration() -> None: """Test async tools work together."""