diff --git a/backend/account_v2/organization.py b/backend/account_v2/organization.py index a8586924d5..eea5cd596c 100644 --- a/backend/account_v2/organization.py +++ b/backend/account_v2/organization.py @@ -1,6 +1,7 @@ import logging from django.db import IntegrityError +from utils.db_retry import db_retry from account_v2.models import Organization @@ -12,6 +13,7 @@ def __init__(self): # type: ignore pass @staticmethod + @db_retry() # Add retry for connection drops during organization lookup def get_organization_by_org_id(org_id: str) -> Organization | None: try: return Organization.objects.get(organization_id=org_id) # type: ignore @@ -19,6 +21,7 @@ def get_organization_by_org_id(org_id: str) -> Organization | None: return None @staticmethod + @db_retry() # Add retry for connection drops during organization creation def create_organization( name: str, display_name: str, organization_id: str ) -> Organization: diff --git a/backend/backend/celery_config.py b/backend/backend/celery_config.py index f8833556e7..80c2299fc0 100644 --- a/backend/backend/celery_config.py +++ b/backend/backend/celery_config.py @@ -1,10 +1,19 @@ +import logging +import os from urllib.parse import quote_plus from django.conf import settings +from backend.celery_db_retry import get_celery_db_engine_options, should_use_builtin_retry + +logger = logging.getLogger(__name__) + class CeleryConfig: - """Specifies celery configuration + """Specifies celery configuration with hybrid retry support. + + Supports both custom retry (via patching) and Celery's built-in retry + based on CELERY_USE_BUILTIN_RETRY environment variable. Refer https://docs.celeryq.dev/en/stable/userguide/configuration.html """ @@ -31,3 +40,30 @@ class CeleryConfig: beat_scheduler = "django_celery_beat.schedulers:DatabaseScheduler" task_acks_late = True + + # Database backend engine options for PgBouncer compatibility + result_backend_transport_options = get_celery_db_engine_options() + + # Hybrid retry configuration - built-in vs custom + if should_use_builtin_retry(): + # Use Celery's built-in database backend retry + result_backend_always_retry = ( + os.environ.get("CELERY_RESULT_BACKEND_ALWAYS_RETRY", "true").lower() == "true" + ) + result_backend_max_retries = int( + os.environ.get("CELERY_RESULT_BACKEND_MAX_RETRIES", "3") + ) + result_backend_base_sleep_between_retries_ms = int( + os.environ.get("CELERY_RESULT_BACKEND_BASE_SLEEP_BETWEEN_RETRIES_MS", "1000") + ) + result_backend_max_sleep_between_retries_ms = int( + os.environ.get("CELERY_RESULT_BACKEND_MAX_SLEEP_BETWEEN_RETRIES_MS", "30000") + ) + + logger.info( + f"[Celery Config] Using built-in retry: max_retries={result_backend_max_retries}, " + f"base_sleep={result_backend_base_sleep_between_retries_ms}ms, max_sleep={result_backend_max_sleep_between_retries_ms}ms" + ) + else: + # Custom retry is handled by patch_celery_database_backend() + logger.info("[Celery Config] Using custom retry system (patching enabled)") diff --git a/backend/backend/celery_db_retry.py b/backend/backend/celery_db_retry.py new file mode 100644 index 0000000000..1fb4550dfd --- /dev/null +++ b/backend/backend/celery_db_retry.py @@ -0,0 +1,457 @@ +import logging +import time +from collections.abc import Callable +from functools import wraps +from typing import Any + +from utils.db_constants import ( + DatabaseErrorPatterns, + DatabaseErrorType, + LogMessages, + RetryConfiguration, +) + +logger = logging.getLogger(__name__) + + +def should_use_builtin_retry() -> bool: + """Check if we should use Celery's built-in retry 
instead of custom retry. + + Returns: + bool: True if built-in retry should be used, False for custom retry + """ + return RetryConfiguration.get_setting_value( + RetryConfiguration.ENV_CELERY_USE_BUILTIN, False + ) + + +def _is_sqlalchemy_error(exception: Exception) -> bool: + """Check if exception is a SQLAlchemy error.""" + try: + from sqlalchemy.exc import OperationalError as SQLAlchemyOperationalError + + return isinstance(exception, SQLAlchemyOperationalError) + except ImportError: + return False + + +def _get_retry_settings_for_error( + error_type: DatabaseErrorType, + retry_count: int, + max_retries: int, + base_delay: float, + max_delay: float, +) -> dict: + """Get appropriate retry settings based on error type.""" + if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE and retry_count == 0: + extended_settings = RetryConfiguration.get_retry_settings(use_extended=True) + return { + "max_retries": extended_settings["max_retries"], + "base_delay": extended_settings["base_delay"], + "max_delay": extended_settings["max_delay"], + } + return { + "max_retries": max_retries, + "base_delay": base_delay, + "max_delay": max_delay, + } + + +def _handle_pool_refresh( + error: BaseException, retry_count: int, total_retries: int +) -> None: + """Handle SQLAlchemy connection pool disposal.""" + logger.warning( + LogMessages.format_message( + LogMessages.POOL_CORRUPTION_DETECTED, + attempt=retry_count, + total=total_retries, + error=error, + ) + ) + try: + _dispose_sqlalchemy_engine() + logger.info("SQLAlchemy connection pool disposed successfully") + except Exception as refresh_error: + logger.warning( + LogMessages.format_message( + LogMessages.POOL_REFRESH_FAILED, + error=refresh_error, + ) + ) + + +def _log_retry_attempt( + error_type: DatabaseErrorType, + retry_count: int, + total_retries: int, + error: BaseException, + delay: float, +) -> None: + """Log retry attempt with appropriate message.""" + if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE: + message = LogMessages.DATABASE_UNAVAILABLE_RETRY + else: + message = LogMessages.CONNECTION_ERROR_RETRY + + logger.warning( + LogMessages.format_message( + message, + attempt=retry_count, + total=total_retries, + error=error, + delay=delay, + ) + ) + + +def _log_non_sqlalchemy_error(error: BaseException) -> None: + """Log non-SQLAlchemy errors.""" + logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) + + +def _log_non_retryable_error(error: BaseException) -> None: + """Log non-retryable errors.""" + logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) + + +def _execute_retry_attempt( + error: BaseException, + error_type: DatabaseErrorType, + needs_refresh: bool, + retry_count: int, + current_settings: dict, +) -> tuple[int, float]: + """Execute a retry attempt and return updated retry_count and delay.""" + delay = min( + current_settings["base_delay"] * (2**retry_count), + current_settings["max_delay"], + ) + retry_count += 1 + + # Handle connection pool refresh if needed + if needs_refresh: + _handle_pool_refresh(error, retry_count, current_settings["max_retries"] + 1) + else: + _log_retry_attempt( + error_type, + retry_count, + current_settings["max_retries"] + 1, + error, + delay, + ) + + return retry_count, delay + + +def _log_max_retries_exceeded(error: BaseException, total_retries: int) -> None: + """Log when max retries are exceeded.""" + logger.error( + LogMessages.format_message( + LogMessages.MAX_RETRIES_EXCEEDED, + total=total_retries, + error=error, + ) + ) 
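For reference, a minimal sketch (not part of the patch) of the backoff schedule that _execute_retry_attempt computes, assuming the defaults in RetryConfiguration (max_retries=3, base_delay=1.0s, max_delay=30.0s): the delay doubles per attempt and is capped at max_delay.

    base_delay, max_delay, max_retries = 1.0, 30.0, 3
    for retry_count in range(max_retries):
        delay = min(base_delay * (2**retry_count), max_delay)
        print(f"retry {retry_count + 1}/{max_retries}: sleep {delay:.0f}s")  # 1s, 2s, 4s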
+ + +def _process_celery_exception( + error: BaseException, + retry_count: int, + max_retries: int, + base_delay: float, + max_delay: float, +) -> tuple[int, float] | None: + """Process a Celery exception and return retry info or None if not retryable. + + Returns: + tuple[int, float]: (updated_retry_count, delay) if should retry + None: if should not retry (will re-raise) + """ + if not _is_sqlalchemy_error(error): + _log_non_sqlalchemy_error(error) + raise + + # Use centralized error classification + error_type, needs_refresh = DatabaseErrorPatterns.classify_error(error) + + if not DatabaseErrorPatterns.is_retryable_error(error_type): + _log_non_retryable_error(error) + raise + + # Get appropriate retry settings for this error type + current_settings = _get_retry_settings_for_error( + error_type, retry_count, max_retries, base_delay, max_delay + ) + + if retry_count < current_settings["max_retries"]: + return _execute_retry_attempt( + error, error_type, needs_refresh, retry_count, current_settings + ) + else: + _log_max_retries_exceeded(error, current_settings["max_retries"] + 1) + raise + + +def _log_success(retry_count: int) -> None: + """Log operation success.""" + if retry_count == 0: + logger.debug(LogMessages.OPERATION_SUCCESS) + else: + logger.info( + LogMessages.format_message( + LogMessages.OPERATION_SUCCESS_AFTER_RETRY, + retry_count=retry_count, + ) + ) + + +def celery_db_retry_with_backoff( + max_retries: int | None = None, + base_delay: float | None = None, + max_delay: float | None = None, +): + """Decorator to retry Celery database backend operations with exponential backoff. + + This is specifically designed for Celery's database result backend operations + that may experience connection drops when using database proxy or database restarts. + + Args: + max_retries: Maximum number of retry attempts (defaults to settings or 3) + base_delay: Initial delay in seconds (defaults to settings or 1.0) + max_delay: Maximum delay in seconds (defaults to settings or 30.0) + """ + # Get defaults from centralized configuration + celery_settings = RetryConfiguration.get_celery_retry_settings() + + if max_retries is None: + max_retries = celery_settings["max_retries"] + if base_delay is None: + base_delay = celery_settings["base_delay"] + if max_delay is None: + max_delay = celery_settings["max_delay"] + + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + logger.debug( + LogMessages.format_message( + LogMessages.CELERY_OPERATION_START, operation=func.__name__ + ) + ) + retry_count = 0 + while retry_count <= max_retries: + try: + result = func(*args, **kwargs) + _log_success(retry_count) + return result + except Exception as e: + retry_info = _process_celery_exception( + e, retry_count, max_retries, base_delay, max_delay + ) + if retry_info: + retry_count, delay = retry_info + time.sleep(delay) + continue + + # This should never be reached, but included for completeness + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def _dispose_sqlalchemy_engine(): + """Dispose SQLAlchemy engine to force connection pool recreation. + + This is called when we detect severe connection issues that require + the entire connection pool to be recreated. 
+ """ + try: + # Try to get the engine from the Celery backend + from celery import current_app + + backend = current_app.backend + if hasattr(backend, "engine") and backend.engine: + backend.engine.dispose() + logger.info("Disposed SQLAlchemy engine for connection pool refresh") + except Exception as e: + logger.warning(f"Could not dispose SQLAlchemy engine: {e}") + + +# Track if patching has been applied to prevent double patching +_celery_backend_patched = False + + +def patch_celery_database_backend(): + """Dynamically patch Celery's database backend classes to add retry logic. + + This function should be called during application initialization to add + connection retry capabilities to Celery's database backend operations. + + Supports hybrid mode: can use either custom retry logic or Celery's built-in retry + based on CELERY_USE_BUILTIN_RETRY environment variable/setting. + """ + global _celery_backend_patched + + # Prevent double patching + if _celery_backend_patched: + logger.debug("Celery database backend already patched, skipping") + return + + # Check if we should use built-in retry instead of custom + if should_use_builtin_retry(): + logger.info( + "Using Celery's built-in database backend retry (CELERY_USE_BUILTIN_RETRY=true)" + ) + _configure_builtin_retry() + _celery_backend_patched = True + return + + logger.info("Using custom database backend retry (CELERY_USE_BUILTIN_RETRY=false)") + + try: + from celery.backends.database import DatabaseBackend + + # Check if already patched by looking for our marker attribute + if hasattr(DatabaseBackend._store_result, "_retry_patched"): + logger.debug( + "Celery database backend already patched (detected marker), skipping" + ) + _celery_backend_patched = True + return + + # Store original methods - bypass Celery's built-in retry by accessing the real methods + # This prevents double-retry conflicts when using custom retry + original_store_result = getattr( + DatabaseBackend._store_result, "__wrapped__", DatabaseBackend._store_result + ) + original_get_task_meta_for = getattr( + DatabaseBackend._get_task_meta_for, + "__wrapped__", + DatabaseBackend._get_task_meta_for, + ) + original_get_result = getattr( + DatabaseBackend.get_result, "__wrapped__", DatabaseBackend.get_result + ) + + logger.info( + "Bypassing Celery's built-in retry system to prevent conflicts (using custom retry)" + ) + + # Apply retry decorators + logger.info("Patching Celery method: _store_result with retry logic") + DatabaseBackend._store_result = celery_db_retry_with_backoff()( + original_store_result + ) + + logger.info("Patching Celery method: _get_task_meta_for with retry logic") + DatabaseBackend._get_task_meta_for = celery_db_retry_with_backoff()( + original_get_task_meta_for + ) + + logger.info("Patching Celery method: get_result with retry logic") + DatabaseBackend.get_result = celery_db_retry_with_backoff()(original_get_result) + + # delete_result may not exist in all Celery versions + if hasattr(DatabaseBackend, "delete_result"): + logger.info("Patching Celery method: delete_result with retry logic") + original_delete_result = getattr( + DatabaseBackend.delete_result, + "__wrapped__", + DatabaseBackend.delete_result, + ) + DatabaseBackend.delete_result = celery_db_retry_with_backoff()( + original_delete_result + ) + else: + logger.info("Celery method delete_result not found, skipping patch") + + # Mark as patched to prevent double patching + DatabaseBackend._store_result._retry_patched = True + _celery_backend_patched = True + + logger.info("Successfully 
patched Celery database backend with retry logic") + + except ImportError as e: + logger.warning(f"Could not import Celery database backend for patching: {e}") + except Exception as e: + logger.error(f"Error patching Celery database backend: {e}") + + +def _configure_builtin_retry(): + """Configure Celery's built-in database backend retry settings. + + This is called when CELERY_USE_BUILTIN_RETRY=true to setup Celery's + native retry functionality instead of our custom implementation. + """ + try: + import os + + # Get retry configuration from centralized settings + settings = RetryConfiguration.get_retry_settings() + + max_retries = settings["max_retries"] + base_delay_ms = int(settings["base_delay"] * 1000) + max_delay_ms = int(settings["max_delay"] * 1000) + + # Apply built-in retry settings to Celery configuration + logger.info( + f"Configured Celery built-in retry: max_retries={max_retries}, base_delay={base_delay_ms}ms, max_delay={max_delay_ms}ms" + ) + + # Store settings for use in celery_config.py + os.environ["CELERY_RESULT_BACKEND_ALWAYS_RETRY"] = "true" + os.environ["CELERY_RESULT_BACKEND_MAX_RETRIES"] = str(max_retries) + os.environ["CELERY_RESULT_BACKEND_BASE_SLEEP_BETWEEN_RETRIES_MS"] = str( + base_delay_ms + ) + os.environ["CELERY_RESULT_BACKEND_MAX_SLEEP_BETWEEN_RETRIES_MS"] = str( + max_delay_ms + ) + + logger.info("Successfully configured Celery's built-in database backend retry") + + except Exception as e: + logger.error(f"Error configuring Celery built-in retry: {e}") + logger.warning("Falling back to no retry mechanism") + + +def get_celery_db_engine_options(): + """Get SQLAlchemy engine options optimized for database connection pooling. + + Includes built-in retry configuration if CELERY_USE_BUILTIN_RETRY is enabled. + + These options are designed to work well with database proxy connection pooling + without interfering with the database proxy's pool management. + + All options are configurable via environment variables for flexibility. + + Returns: + dict: SQLAlchemy engine options + """ + return { + # Connection health checking + "pool_pre_ping": RetryConfiguration.get_setting_value( + "CELERY_DB_POOL_PRE_PING", True + ), # Test connections before use + # Minimal pooling (let database proxy handle the real pooling) + "pool_size": RetryConfiguration.get_setting_value( + "CELERY_DB_POOL_SIZE", 5 + ), # Small pool since database proxy handles real pooling + "max_overflow": RetryConfiguration.get_setting_value( + "CELERY_DB_MAX_OVERFLOW", 0 + ), # No overflow, rely on database proxy + "pool_recycle": RetryConfiguration.get_setting_value( + "CELERY_DB_POOL_RECYCLE", 3600 + ), # Recycle connections every hour + # Connection timeouts using centralized configuration + "connect_args": { + "connect_timeout": RetryConfiguration.get_setting_value( + "CELERY_DB_CONNECT_TIMEOUT", 30 + ), + }, + # Echo SQL queries if debug logging is enabled + "echo": RetryConfiguration.get_setting_value("CELERY_DB_ECHO_SQL", False), + } diff --git a/backend/backend/celery_service.py b/backend/backend/celery_service.py index 9b0e19b6f0..ec5ddc63d9 100644 --- a/backend/backend/celery_service.py +++ b/backend/backend/celery_service.py @@ -7,6 +7,7 @@ from celery import Celery +from backend.celery_db_retry import patch_celery_database_backend from backend.celery_task import TaskRegistry from backend.settings.base import LOGGING @@ -24,6 +25,9 @@ # Create a Celery instance. Default time zone is UTC. 
app = Celery("backend") +# Patch Celery database backend to add connection retry logic +patch_celery_database_backend() + # Register custom tasks TaskRegistry() diff --git a/backend/backend/settings/base.py b/backend/backend/settings/base.py index cb18de0ca5..7daf4c9639 100644 --- a/backend/backend/settings/base.py +++ b/backend/backend/settings/base.py @@ -115,6 +115,18 @@ def get_required_setting(setting_key: str, default: str | None = None) -> str | DB_PORT = os.environ.get("DB_PORT", 5432) DB_SCHEMA = os.environ.get("DB_SCHEMA", "unstract") +# Database connection retry settings +DB_CONNECTION_RETRY_COUNT = int(os.environ.get("DB_CONNECTION_RETRY_COUNT", "3")) +DB_CONNECTION_RETRY_DELAY = int(os.environ.get("DB_CONNECTION_RETRY_DELAY", "1")) +DB_CONNECTION_RETRY_MAX_DELAY = int(os.environ.get("DB_CONNECTION_RETRY_MAX_DELAY", "30")) + +# Celery database backend retry settings +CELERY_DB_RETRY_COUNT = int(os.environ.get("CELERY_DB_RETRY_COUNT", "3")) +CELERY_DB_RETRY_DELAY = int(os.environ.get("CELERY_DB_RETRY_DELAY", "1")) +CELERY_DB_RETRY_MAX_DELAY = int(os.environ.get("CELERY_DB_RETRY_MAX_DELAY", "30")) +CELERY_DB_CONNECT_TIMEOUT = int(os.environ.get("CELERY_DB_CONNECT_TIMEOUT", "30")) +CELERY_DB_ECHO_SQL = os.environ.get("CELERY_DB_ECHO_SQL", "False").lower() == "true" + # Celery Backend Database Name (falls back to main DB when unset or empty) CELERY_BACKEND_DB_NAME = os.environ.get("CELERY_BACKEND_DB_NAME") or DB_NAME DEFAULT_ORGANIZATION = "default_org" diff --git a/backend/backend/workers/file_processing/file_processing.py b/backend/backend/workers/file_processing/file_processing.py index b6c1f975e6..997e5c6792 100644 --- a/backend/backend/workers/file_processing/file_processing.py +++ b/backend/backend/workers/file_processing/file_processing.py @@ -7,6 +7,7 @@ from celery import Celery +from backend.celery_db_retry import patch_celery_database_backend from backend.settings.base import LOGGING from backend.workers.constants import CeleryWorkerNames from backend.workers.file_processing.celery_config import CeleryConfig @@ -25,6 +26,9 @@ # Create a Celery instance. Default time zone is UTC. app = Celery(CeleryWorkerNames.FILE_PROCESSING) +# Patch Celery database backend to add connection retry logic +patch_celery_database_backend() + # Load task modules from all registered Django app configs. app.config_from_object(CeleryConfig) diff --git a/backend/backend/workers/file_processing_callback/file_processing_callback.py b/backend/backend/workers/file_processing_callback/file_processing_callback.py index ecc20c6436..764e17fefd 100644 --- a/backend/backend/workers/file_processing_callback/file_processing_callback.py +++ b/backend/backend/workers/file_processing_callback/file_processing_callback.py @@ -7,6 +7,7 @@ from celery import Celery +from backend.celery_db_retry import patch_celery_database_backend from backend.settings.base import LOGGING from backend.workers.constants import CeleryWorkerNames from backend.workers.file_processing_callback.celery_config import CeleryConfig @@ -25,6 +26,9 @@ # Create a Celery instance. Default time zone is UTC. app = Celery(CeleryWorkerNames.FILE_PROCESSING_CALLBACK) +# Patch Celery database backend to add connection retry logic +patch_celery_database_backend() + # Load task modules from all registered Django app configs. 
app.config_from_object(CeleryConfig) diff --git a/backend/sample.env b/backend/sample.env index dcd79b5a16..816109f979 100644 --- a/backend/sample.env +++ b/backend/sample.env @@ -25,6 +25,24 @@ DB_SCHEMA="unstract" # Celery Backend Database (optional - defaults to DB_NAME if unset) # Example: # CELERY_BACKEND_DB_NAME=unstract_celery_db + +# Database connection retry settings (for handling connection drops with database proxy) +# DB_CONNECTION_RETRY_COUNT=3 +# DB_CONNECTION_RETRY_DELAY=1 +# DB_CONNECTION_RETRY_MAX_DELAY=30 + +# Celery database backend retry settings (for handling connection drops with database proxy) +# CELERY_DB_RETRY_COUNT=3 +# CELERY_DB_RETRY_DELAY=1 +# CELERY_DB_RETRY_MAX_DELAY=30 + +# Celery database backend connection pool settings +# CELERY_DB_POOL_PRE_PING=True # Test connections before use +# CELERY_DB_POOL_SIZE=5 # Connection pool size +# CELERY_DB_MAX_OVERFLOW=0 # Max overflow connections +# CELERY_DB_POOL_RECYCLE=3600 # Recycle connections after (seconds) +# CELERY_DB_CONNECT_TIMEOUT=30 # Connection timeout (seconds) +# CELERY_DB_ECHO_SQL=False # Echo SQL queries for debugging # Redis REDIS_HOST="unstract-redis" REDIS_PORT=6379 diff --git a/backend/usage_v2/helper.py b/backend/usage_v2/helper.py index 9452547acb..29e3d108ff 100644 --- a/backend/usage_v2/helper.py +++ b/backend/usage_v2/helper.py @@ -4,6 +4,7 @@ from django.db.models import QuerySet, Sum from rest_framework.exceptions import APIException +from utils.db_retry import db_retry from .constants import UsageKeys from .models import Usage @@ -13,6 +14,7 @@ class UsageHelper: @staticmethod + @db_retry() # Add retry for connection drops during usage aggregation def get_aggregated_token_count(run_id: str) -> dict: """Retrieve aggregated token counts for the given run_id. @@ -64,6 +66,7 @@ def get_aggregated_token_count(run_id: str) -> dict: raise APIException("Error while aggregating token counts") @staticmethod + @db_retry() # Add retry for connection drops during metrics aggregation def aggregate_usage_metrics(queryset: QuerySet) -> dict[str, Any]: """Aggregate usage metrics from a queryset of Usage objects. diff --git a/backend/utils/db_constants.py b/backend/utils/db_constants.py new file mode 100644 index 0000000000..0796a78f38 --- /dev/null +++ b/backend/utils/db_constants.py @@ -0,0 +1,329 @@ +"""Database Error Constants and Configuration + +Centralized error patterns, types, and configuration for database retry mechanisms. +Used by both Django ORM retry (db_retry.py) and Celery SQLAlchemy retry (celery_db_retry.py). +""" + +import os +from dataclasses import dataclass +from enum import Enum + + +class DatabaseErrorType(Enum): + """Classification of database connection errors.""" + + STALE_CONNECTION = "stale_connection" + DATABASE_UNAVAILABLE = "database_unavailable" + TRANSIENT_ERROR = "transient_error" + NON_CONNECTION_ERROR = "non_connection_error" + + +class ConnectionPoolType(Enum): + """Types of connection pools that need different refresh strategies.""" + + DJANGO_ORM = "django_orm" + SQLALCHEMY = "sqlalchemy" + + +@dataclass(frozen=True) +class ErrorPattern: + """Defines an error pattern with its classification and handling strategy.""" + + keywords: tuple[str, ...] 
+ error_type: DatabaseErrorType + requires_pool_refresh: bool + description: str + + +class DatabaseErrorPatterns: + """Centralized database error patterns for consistent classification.""" + + # Stale connection patterns - connection exists but is dead + STALE_CONNECTION_PATTERNS = ( + ErrorPattern( + keywords=("connection already closed",), + error_type=DatabaseErrorType.STALE_CONNECTION, + requires_pool_refresh=True, + description="Django InterfaceError - connection pool corruption", + ), + ErrorPattern( + keywords=("server closed the connection unexpectedly",), + error_type=DatabaseErrorType.STALE_CONNECTION, + requires_pool_refresh=True, + description="PostgreSQL dropped connection, pool may be stale", + ), + ErrorPattern( + keywords=("connection was lost",), + error_type=DatabaseErrorType.STALE_CONNECTION, + requires_pool_refresh=True, + description="Connection lost during operation", + ), + ) + + # Database unavailable patterns - database is completely unreachable + DATABASE_UNAVAILABLE_PATTERNS = ( + ErrorPattern( + keywords=("connection refused",), + error_type=DatabaseErrorType.DATABASE_UNAVAILABLE, + requires_pool_refresh=True, + description="Database server is down or unreachable", + ), + ErrorPattern( + keywords=("could not connect",), + error_type=DatabaseErrorType.DATABASE_UNAVAILABLE, + requires_pool_refresh=True, + description="Unable to establish database connection", + ), + ErrorPattern( + keywords=("no route to host",), + error_type=DatabaseErrorType.DATABASE_UNAVAILABLE, + requires_pool_refresh=True, + description="Network routing issue to database", + ), + ) + + # Transient error patterns - temporary issues, retry without pool refresh + TRANSIENT_ERROR_PATTERNS = ( + ErrorPattern( + keywords=("connection timeout",), + error_type=DatabaseErrorType.TRANSIENT_ERROR, + requires_pool_refresh=False, + description="Connection attempt timed out", + ), + ErrorPattern( + keywords=("connection pool exhausted",), + error_type=DatabaseErrorType.TRANSIENT_ERROR, + requires_pool_refresh=False, + description="Connection pool temporarily full", + ), + ErrorPattern( + keywords=("connection closed",), + error_type=DatabaseErrorType.TRANSIENT_ERROR, + requires_pool_refresh=False, + description="Connection closed during operation", + ), + # MySQL compatibility patterns + ErrorPattern( + keywords=("lost connection to mysql server",), + error_type=DatabaseErrorType.TRANSIENT_ERROR, + requires_pool_refresh=False, + description="MySQL connection lost", + ), + ErrorPattern( + keywords=("mysql server has gone away",), + error_type=DatabaseErrorType.TRANSIENT_ERROR, + requires_pool_refresh=False, + description="MySQL server disconnected", + ), + ) + + # All patterns combined for easy iteration + ALL_PATTERNS = ( + STALE_CONNECTION_PATTERNS + + DATABASE_UNAVAILABLE_PATTERNS + + TRANSIENT_ERROR_PATTERNS + ) + + @classmethod + def classify_error( + cls, error: Exception, error_message: str = None + ) -> tuple[DatabaseErrorType, bool]: + """Classify a database error and determine if pool refresh is needed. 
+ + Args: + error: The exception to classify + error_message: Optional pre-lowercased error message for efficiency + + Returns: + Tuple[DatabaseErrorType, bool]: (error_type, needs_pool_refresh) + """ + if error_message is None: + error_message = str(error).lower() + else: + error_message = error_message.lower() + + # Check all patterns for matches + for pattern in cls.ALL_PATTERNS: + if any(keyword in error_message for keyword in pattern.keywords): + return pattern.error_type, pattern.requires_pool_refresh + + # No pattern matched + return DatabaseErrorType.NON_CONNECTION_ERROR, False + + @classmethod + def is_retryable_error(cls, error_type: DatabaseErrorType) -> bool: + """Check if an error type should be retried.""" + return error_type in { + DatabaseErrorType.STALE_CONNECTION, + DatabaseErrorType.DATABASE_UNAVAILABLE, + DatabaseErrorType.TRANSIENT_ERROR, + } + + @classmethod + def get_all_error_keywords(cls) -> list[str]: + """Get all error keywords as a flat list (for backward compatibility).""" + keywords = [] + for pattern in cls.ALL_PATTERNS: + keywords.extend(pattern.keywords) + return keywords + + +class RetryConfiguration: + """Centralized configuration for database retry settings.""" + + # Environment variable names + ENV_MAX_RETRIES = "DB_RETRY_MAX_RETRIES" + ENV_BASE_DELAY = "DB_RETRY_BASE_DELAY" + ENV_MAX_DELAY = "DB_RETRY_MAX_DELAY" + ENV_FORCE_REFRESH = "DB_RETRY_FORCE_REFRESH" + + # Celery-specific environment variables + ENV_CELERY_USE_BUILTIN = "CELERY_USE_BUILTIN_RETRY" + ENV_CELERY_MAX_RETRIES = "CELERY_DB_RETRY_COUNT" + ENV_CELERY_BASE_DELAY = "CELERY_DB_RETRY_DELAY" + ENV_CELERY_MAX_DELAY = "CELERY_DB_RETRY_MAX_DELAY" + + # Default values + DEFAULT_MAX_RETRIES = 3 + DEFAULT_BASE_DELAY = 1.0 + DEFAULT_MAX_DELAY = 30.0 + DEFAULT_FORCE_REFRESH = True + + # Extended retry for database unavailable scenarios + DEFAULT_EXTENDED_MAX_RETRIES = 8 + DEFAULT_EXTENDED_BASE_DELAY = 2.0 + DEFAULT_EXTENDED_MAX_DELAY = 60.0 + + @classmethod + def get_setting_value(cls, setting_name: str, default_value, use_django=True): + """Get setting value from Django settings or environment with proper type conversion. 
+ + Args: + setting_name: Name of the setting + default_value: Default value (determines return type) + use_django: Whether to try Django settings first + + Returns: + Setting value converted to same type as default_value + """ + # Try Django settings first (if available and requested) + if use_django: + try: + from django.conf import settings + from django.core.exceptions import ImproperlyConfigured + + if hasattr(settings, setting_name): + return getattr(settings, setting_name) + except (ImportError, ImproperlyConfigured): + # Django not available or not configured + pass + + # Fall back to environment variables + env_value = os.environ.get(setting_name) + if env_value is not None: + return cls._convert_env_value(env_value, default_value) + + return default_value + + @classmethod + def _convert_env_value(cls, env_value: str, default_value): + """Convert environment variable string to appropriate type.""" + try: + # Check bool first since bool is subclass of int in Python + if isinstance(default_value, bool): + return env_value.lower() in ("true", "1", "yes", "on") + elif isinstance(default_value, int): + return int(env_value) + elif isinstance(default_value, float): + return float(env_value) + else: + return env_value + except (ValueError, TypeError): + return default_value + + @classmethod + def get_retry_settings(cls, use_extended=False, use_django=True) -> dict: + """Get complete retry configuration. + + Args: + use_extended: Use extended settings for database unavailable scenarios + use_django: Whether to check Django settings + + Returns: + Dict with retry configuration + """ + if use_extended: + max_retries_default = cls.DEFAULT_EXTENDED_MAX_RETRIES + base_delay_default = cls.DEFAULT_EXTENDED_BASE_DELAY + max_delay_default = cls.DEFAULT_EXTENDED_MAX_DELAY + else: + max_retries_default = cls.DEFAULT_MAX_RETRIES + base_delay_default = cls.DEFAULT_BASE_DELAY + max_delay_default = cls.DEFAULT_MAX_DELAY + + return { + "max_retries": cls.get_setting_value( + cls.ENV_MAX_RETRIES, max_retries_default, use_django + ), + "base_delay": cls.get_setting_value( + cls.ENV_BASE_DELAY, base_delay_default, use_django + ), + "max_delay": cls.get_setting_value( + cls.ENV_MAX_DELAY, max_delay_default, use_django + ), + "force_refresh": cls.get_setting_value( + cls.ENV_FORCE_REFRESH, cls.DEFAULT_FORCE_REFRESH, use_django + ), + } + + @classmethod + def get_celery_retry_settings(cls, use_django=True) -> dict: + """Get Celery-specific retry configuration.""" + return { + "max_retries": cls.get_setting_value( + cls.ENV_CELERY_MAX_RETRIES, cls.DEFAULT_MAX_RETRIES, use_django + ), + "base_delay": cls.get_setting_value( + cls.ENV_CELERY_BASE_DELAY, cls.DEFAULT_BASE_DELAY, use_django + ), + "max_delay": cls.get_setting_value( + cls.ENV_CELERY_MAX_DELAY, cls.DEFAULT_MAX_DELAY, use_django + ), + "use_builtin": cls.get_setting_value( + cls.ENV_CELERY_USE_BUILTIN, False, use_django + ), + } + + +class LogMessages: + """Centralized log message templates for consistent logging.""" + + # Success messages + OPERATION_SUCCESS = "Operation completed successfully" + OPERATION_SUCCESS_AFTER_RETRY = ( + "Operation succeeded after {retry_count} retry attempts" + ) + + # Connection pool refresh messages + POOL_CORRUPTION_DETECTED = "Database connection pool corruption detected (attempt {attempt}/{total}): {error}. Refreshing connection pool..." 
+ POOL_REFRESH_SUCCESS = "Connection pool refreshed successfully" + POOL_REFRESH_FAILED = "Failed to refresh connection pool: {error}" + + # Retry attempt messages + CONNECTION_ERROR_RETRY = "Database connection error (attempt {attempt}/{total}): {error}. Retrying in {delay} seconds..." + DATABASE_UNAVAILABLE_RETRY = "Database unavailable (attempt {attempt}/{total}): {error}. Using extended retry in {delay} seconds..." + + # Failure messages + MAX_RETRIES_EXCEEDED = "Database connection failed after {total} attempts: {error}" + NON_RETRYABLE_ERROR = "Non-retryable error, not retrying: {error}" + + # Execution messages + EXECUTING_WITH_RETRY = "Executing operation with retry logic..." + CELERY_OPERATION_START = ( + "Executing Celery DB operation: {operation} (retry mechanism active)" + ) + + @classmethod + def format_message(cls, template: str, **kwargs) -> str: + """Format a log message template with provided kwargs.""" + return template.format(**kwargs) diff --git a/backend/utils/db_retry.py b/backend/utils/db_retry.py new file mode 100644 index 0000000000..66e264db23 --- /dev/null +++ b/backend/utils/db_retry.py @@ -0,0 +1,516 @@ +"""Database Retry Utility + +A simple, focused utility for retrying Django ORM operations that fail due to +connection issues. Designed for use in models, views, services, and anywhere +database operations might encounter connection drops. + +Features: +- Automatic connection pool refresh for stale connections +- Extended retry logic for database unavailable scenarios +- Centralized error classification and handling +- Environment variable configuration +- Multiple usage patterns: decorators, context managers, direct calls + +Usage Examples: + + # Decorator usage + @db_retry(max_retries=3) + def my_database_operation(): + Model.objects.create(...) + + # Context manager usage + with db_retry_context(max_retries=3): + model.save() + other_model.delete() + + # Direct function usage + result = retry_database_operation( + lambda: MyModel.objects.filter(...).update(...), + max_retries=3 + ) +""" + +import logging +import time +from collections.abc import Callable +from contextlib import contextmanager +from functools import wraps +from typing import Any + +from django.db import close_old_connections +from django.db.utils import InterfaceError, OperationalError + +from utils.db_constants import ( + DatabaseErrorPatterns, + DatabaseErrorType, + LogMessages, + RetryConfiguration, +) + +logger = logging.getLogger(__name__) + + +# Legacy function for backward compatibility - now delegates to RetryConfiguration +def get_retry_setting(setting_name: str, default_value): + """Get retry setting from Django settings or environment variables with fallback to default. + + Args: + setting_name: The setting name to look for + default_value: Default value if setting is not found + + Returns: + The setting value (int or float based on default_value type) + """ + return RetryConfiguration.get_setting_value(setting_name, default_value) + + +def get_default_retry_settings(use_extended=False): + """Get default retry settings from environment or use built-in defaults. 
+ + Environment variables: + DB_RETRY_MAX_RETRIES: Maximum number of retry attempts (default: 3 or 8 if extended) + DB_RETRY_BASE_DELAY: Initial delay between retries in seconds (default: 1.0 or 2.0 if extended) + DB_RETRY_MAX_DELAY: Maximum delay between retries in seconds (default: 30.0 or 60.0 if extended) + DB_RETRY_FORCE_REFRESH: Force connection pool refresh on stale connections (default: True) + + Args: + use_extended: Use extended retry settings for database unavailable scenarios + + Returns: + dict: Dictionary with retry settings + """ + return RetryConfiguration.get_retry_settings(use_extended=use_extended) + + +def _classify_database_error(error: Exception): + """Classify a database error using centralized error patterns. + + Args: + error: The exception to classify + + Returns: + Tuple containing (error_type, needs_refresh, use_extended_retry) + """ + # Only classify Django database errors + if not isinstance(error, (OperationalError, InterfaceError)): + return DatabaseErrorType.NON_CONNECTION_ERROR, False, False + + error_type, needs_refresh = DatabaseErrorPatterns.classify_error(error) + + # For InterfaceError with stale connection patterns, always refresh + if ( + isinstance(error, InterfaceError) + and error_type == DatabaseErrorType.STALE_CONNECTION + ): + needs_refresh = True + + # Use extended retry for database unavailable scenarios + use_extended_retry = error_type == DatabaseErrorType.DATABASE_UNAVAILABLE + + return error_type, needs_refresh, use_extended_retry + + +def _update_retry_settings_for_extended(retry_count: int, use_extended: bool) -> dict: + """Get extended retry settings if needed for database unavailable errors.""" + if use_extended and retry_count == 0: + return get_default_retry_settings(use_extended=True) + return {} + + +def _handle_connection_pool_refresh( + error: BaseException, retry_count: int, max_retries: int +) -> None: + """Handle connection pool refresh for stale connections.""" + logger.warning( + LogMessages.format_message( + LogMessages.POOL_CORRUPTION_DETECTED, + attempt=retry_count, + total=max_retries + 1, + error=error, + ) + ) + try: + close_old_connections() + logger.info(LogMessages.POOL_REFRESH_SUCCESS) + except Exception as refresh_error: + logger.warning( + LogMessages.format_message( + LogMessages.POOL_REFRESH_FAILED, error=refresh_error + ) + ) + + +def _log_retry_attempt( + error_type: DatabaseErrorType, + retry_count: int, + max_retries: int, + error: BaseException, + delay: float, +) -> None: + """Log retry attempt with appropriate message based on error type.""" + if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE: + message = LogMessages.DATABASE_UNAVAILABLE_RETRY + else: + message = LogMessages.CONNECTION_ERROR_RETRY + + logger.warning( + LogMessages.format_message( + message, + attempt=retry_count, + total=max_retries + 1, + error=error, + delay=delay, + ) + ) + + +def _handle_retry_attempt( + error: BaseException, + error_type: DatabaseErrorType, + needs_refresh: bool, + force_refresh: bool, + retry_count: int, + max_retries: int, + delay: float, +) -> None: + """Handle a single retry attempt - either refresh pool or log retry.""" + if needs_refresh and force_refresh: + _handle_connection_pool_refresh(error, retry_count, max_retries) + else: + _log_retry_attempt(error_type, retry_count, max_retries, error, delay) + + +def _execute_with_retry( + operation: Callable, + max_retries: int = 3, + base_delay: float = 1.0, + max_delay: float = 30.0, + force_refresh: bool = True, +) -> Any: + """Execute an operation with 
retry logic for connection errors. + + Args: + operation: The operation to execute + max_retries: Maximum number of retry attempts + base_delay: Initial delay between retries in seconds + max_delay: Maximum delay between retries in seconds + force_refresh: Whether to force connection pool refresh on stale connections + + Returns: + The result of the operation + + Raises: + The original exception if max retries exceeded or non-connection error + """ + retry_count = 0 + + while retry_count <= max_retries: + try: + logger.debug(LogMessages.EXECUTING_WITH_RETRY) + return operation() + except Exception as e: + error_type, needs_refresh, use_extended = _classify_database_error(e) + + if not DatabaseErrorPatterns.is_retryable_error(error_type): + logger.debug( + LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=e) + ) + raise + + # Update settings for extended retry if needed + extended_settings = _update_retry_settings_for_extended( + retry_count, use_extended + ) + if extended_settings: + max_retries = extended_settings["max_retries"] + base_delay = extended_settings["base_delay"] + max_delay = extended_settings["max_delay"] + + if retry_count < max_retries: + delay = min(base_delay * (2**retry_count), max_delay) + retry_count += 1 + + _handle_retry_attempt( + e, + error_type, + needs_refresh, + force_refresh, + retry_count, + max_retries, + delay, + ) + + time.sleep(delay) + continue + else: + logger.error( + LogMessages.format_message( + LogMessages.MAX_RETRIES_EXCEEDED, total=max_retries + 1, error=e + ) + ) + raise + + # This should never be reached, but included for completeness + return operation() + + +def db_retry( + max_retries: int | None = None, + base_delay: float | None = None, + max_delay: float | None = None, + force_refresh: bool | None = None, +) -> Callable: + """Decorator to retry database operations on connection errors. + + Uses environment variables DB_RETRY_MAX_RETRIES, DB_RETRY_BASE_DELAY, DB_RETRY_MAX_DELAY, + DB_RETRY_FORCE_REFRESH for defaults if parameters are not provided. 
+ + Args: + max_retries: Maximum number of retry attempts (default from env: DB_RETRY_MAX_RETRIES or 3) + base_delay: Initial delay between retries in seconds (default from env: DB_RETRY_BASE_DELAY or 1.0) + max_delay: Maximum delay between retries in seconds (default from env: DB_RETRY_MAX_DELAY or 30.0) + force_refresh: Force connection pool refresh on stale connections (default from env: DB_RETRY_FORCE_REFRESH or True) + + Returns: + Decorated function with retry logic + + Example: + @db_retry(max_retries=5, base_delay=0.5, force_refresh=True) + def create_user(name, email): + return User.objects.create(name=name, email=email) + + # Using environment defaults + @db_retry() + def save_model(self): + self.save() + """ + # Get defaults from environment if not provided + defaults = get_default_retry_settings() + final_max_retries = ( + max_retries if max_retries is not None else defaults["max_retries"] + ) + final_base_delay = base_delay if base_delay is not None else defaults["base_delay"] + final_max_delay = max_delay if max_delay is not None else defaults["max_delay"] + final_force_refresh = ( + force_refresh if force_refresh is not None else defaults["force_refresh"] + ) + + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + def operation(): + return func(*args, **kwargs) + + return _execute_with_retry( + operation=operation, + max_retries=final_max_retries, + base_delay=final_base_delay, + max_delay=final_max_delay, + force_refresh=final_force_refresh, + ) + + return wrapper + + return decorator + + +def _log_context_non_retryable_error(error: BaseException) -> None: + """Log non-retryable errors in context manager.""" + logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) + + +def _log_context_max_retries_exceeded(error: BaseException, total_retries: int) -> None: + """Log max retries exceeded in context manager.""" + logger.error( + LogMessages.format_message( + LogMessages.MAX_RETRIES_EXCEEDED, + total=total_retries, + error=error, + ) + ) + + +def _update_context_retry_settings( + retry_count: int, use_extended: bool, current_settings: dict +) -> dict: + """Update retry settings for extended retry if needed.""" + extended_settings = _update_retry_settings_for_extended(retry_count, use_extended) + if extended_settings: + return { + "max_retries": extended_settings["max_retries"], + "base_delay": extended_settings["base_delay"], + "max_delay": extended_settings["max_delay"], + "force_refresh": current_settings["force_refresh"], + } + return current_settings + + +def _handle_context_retry_logic( + retry_count: int, + final_max_retries: int, + final_base_delay: float, + final_max_delay: float, + final_force_refresh: bool, + error: BaseException, + error_type: DatabaseErrorType, + needs_refresh: bool, +) -> tuple[int, float]: + """Handle retry logic for context manager - returns updated retry_count and delay.""" + delay = min(final_base_delay * (2**retry_count), final_max_delay) + retry_count += 1 + + _handle_retry_attempt( + error, + error_type, + needs_refresh, + final_force_refresh, + retry_count, + final_max_retries, + delay, + ) + + return retry_count, delay + + +@contextmanager +def db_retry_context( + max_retries: int | None = None, + base_delay: float | None = None, + max_delay: float | None = None, + force_refresh: bool | None = None, +): + """Context manager for retrying database operations on connection errors. 
+ + Uses environment variables DB_RETRY_MAX_RETRIES, DB_RETRY_BASE_DELAY, DB_RETRY_MAX_DELAY, + DB_RETRY_FORCE_REFRESH for defaults if parameters are not provided. + + Args: + max_retries: Maximum number of retry attempts (default from env: DB_RETRY_MAX_RETRIES or 3) + base_delay: Initial delay between retries in seconds (default from env: DB_RETRY_BASE_DELAY or 1.0) + max_delay: Maximum delay between retries in seconds (default from env: DB_RETRY_MAX_DELAY or 30.0) + force_refresh: Force connection pool refresh on stale connections (default from env: DB_RETRY_FORCE_REFRESH or True) + + Yields: + None + + Example: + with db_retry_context(max_retries=5, force_refresh=True): + model.save() + other_model.delete() + MyModel.objects.filter(...).update(...) + """ + # Get defaults from environment if not provided + defaults = get_default_retry_settings() + current_settings = { + "max_retries": max_retries + if max_retries is not None + else defaults["max_retries"], + "base_delay": base_delay if base_delay is not None else defaults["base_delay"], + "max_delay": max_delay if max_delay is not None else defaults["max_delay"], + "force_refresh": force_refresh + if force_refresh is not None + else defaults["force_refresh"], + } + + retry_count = 0 + + while retry_count <= current_settings["max_retries"]: + try: + yield + return # Success - exit the retry loop + except Exception as e: + error_type, needs_refresh, use_extended = _classify_database_error(e) + + if not DatabaseErrorPatterns.is_retryable_error(error_type): + _log_context_non_retryable_error(e) + raise + + # Update settings for extended retry if needed + current_settings = _update_context_retry_settings( + retry_count, use_extended, current_settings + ) + + if retry_count < current_settings["max_retries"]: + retry_count, delay = _handle_context_retry_logic( + retry_count, + current_settings["max_retries"], + current_settings["base_delay"], + current_settings["max_delay"], + current_settings["force_refresh"], + e, + error_type, + needs_refresh, + ) + time.sleep(delay) + continue + else: + _log_context_max_retries_exceeded(e, current_settings["max_retries"] + 1) + raise + + +def retry_database_operation( + operation: Callable, + max_retries: int | None = None, + base_delay: float | None = None, + max_delay: float | None = None, + force_refresh: bool | None = None, +) -> Any: + """Execute a database operation with retry logic. + + Uses environment variables DB_RETRY_MAX_RETRIES, DB_RETRY_BASE_DELAY, DB_RETRY_MAX_DELAY, + DB_RETRY_FORCE_REFRESH for defaults if parameters are not provided. 
+ + Args: + operation: A callable that performs the database operation + max_retries: Maximum number of retry attempts (default from env: DB_RETRY_MAX_RETRIES or 3) + base_delay: Initial delay between retries in seconds (default from env: DB_RETRY_BASE_DELAY or 1.0) + max_delay: Maximum delay between retries in seconds (default from env: DB_RETRY_MAX_DELAY or 30.0) + force_refresh: Force connection pool refresh on stale connections (default from env: DB_RETRY_FORCE_REFRESH or True) + + Returns: + The result of the operation + + Example: + result = retry_database_operation( + lambda: MyModel.objects.filter(active=True).update(status='processed'), + max_retries=5, + force_refresh=True + ) + """ + # Get defaults from environment if not provided + defaults = get_default_retry_settings() + final_max_retries = ( + max_retries if max_retries is not None else defaults["max_retries"] + ) + final_base_delay = base_delay if base_delay is not None else defaults["base_delay"] + final_max_delay = max_delay if max_delay is not None else defaults["max_delay"] + final_force_refresh = ( + force_refresh if force_refresh is not None else defaults["force_refresh"] + ) + + return _execute_with_retry( + operation=operation, + max_retries=final_max_retries, + base_delay=final_base_delay, + max_delay=final_max_delay, + force_refresh=final_force_refresh, + ) + + +# Convenience function with default settings +def quick_retry(operation: Callable) -> Any: + """Execute a database operation with environment/default retry settings. + + Uses environment variables DB_RETRY_MAX_RETRIES, DB_RETRY_BASE_DELAY, DB_RETRY_MAX_DELAY + or built-in defaults (3 retries, 1.0s base delay, 30.0s max delay). + + Args: + operation: A callable that performs the database operation + + Returns: + The result of the operation + + Example: + result = quick_retry(lambda: model.save()) + """ + return retry_database_operation(operation) diff --git a/backend/utils/models/organization_mixin.py b/backend/utils/models/organization_mixin.py index 1e178fe0c4..a041fdb81d 100644 --- a/backend/utils/models/organization_mixin.py +++ b/backend/utils/models/organization_mixin.py @@ -1,6 +1,7 @@ # TODO:V2 class from account_v2.models import Organization from django.db import models +from utils.db_retry import db_retry from utils.user_context import UserContext @@ -17,6 +18,7 @@ class DefaultOrganizationMixin(models.Model): class Meta: abstract = True + @db_retry() # Add retry for connection drops during organization assignment def save(self, *args, **kwargs): if self.organization is None: self.organization = UserContext.get_organization() @@ -24,6 +26,7 @@ def save(self, *args, **kwargs): class DefaultOrganizationManagerMixin(models.Manager): + @db_retry() # Add retry for connection drops during queryset organization filtering def get_queryset(self): organization = UserContext.get_organization() return super().get_queryset().filter(organization=organization) diff --git a/backend/utils/user_context.py b/backend/utils/user_context.py index 71d63806b5..4b5c7dce54 100644 --- a/backend/utils/user_context.py +++ b/backend/utils/user_context.py @@ -2,6 +2,7 @@ from django.db.utils import ProgrammingError from utils.constants import Account +from utils.db_retry import db_retry from utils.local_context import StateStore @@ -16,6 +17,7 @@ def set_organization_identifier(organization_identifier: str) -> None: StateStore.set(Account.ORGANIZATION_ID, organization_identifier) @staticmethod + @db_retry() # Add retry for connection drops during organization lookup def 
get_organization() -> Organization | None: organization_id = StateStore.get(Account.ORGANIZATION_ID) try: diff --git a/backend/workflow_manager/file_execution/models.py b/backend/workflow_manager/file_execution/models.py index fd75b163fb..4ce6082d31 100644 --- a/backend/workflow_manager/file_execution/models.py +++ b/backend/workflow_manager/file_execution/models.py @@ -4,6 +4,7 @@ from django.db import models from utils.common_utils import CommonUtils +from utils.db_retry import db_retry from utils.models.base_model import BaseModel from workflow_manager.endpoint_v2.dto import FileHash @@ -16,6 +17,7 @@ class WorkflowFileExecutionManager(models.Manager): + @db_retry() # Use environment defaults for retry settings def get_or_create_file_execution( self, workflow_execution: Any, @@ -53,6 +55,7 @@ def get_or_create_file_execution( return execution_file + @db_retry() # Use environment defaults for retry settings def _update_execution_file( self, execution_file: "WorkflowFileExecution", file_hash: FileHash ) -> None: @@ -118,6 +121,7 @@ def __str__(self): f"(WorkflowExecution: {self.workflow_execution})" ) + @db_retry() # Use environment defaults for retry settings def update_status( self, status: ExecutionStatus, @@ -221,6 +225,7 @@ def is_completed(self) -> bool: """ return self.status is not None and self.status == ExecutionStatus.COMPLETED + @db_retry() # Use environment defaults for retry settings def update( self, file_hash: str = None, diff --git a/backend/workflow_manager/workflow_v2/execution.py b/backend/workflow_manager/workflow_v2/execution.py index 3de4cb6c10..2c0adab7e8 100644 --- a/backend/workflow_manager/workflow_v2/execution.py +++ b/backend/workflow_manager/workflow_v2/execution.py @@ -9,6 +9,7 @@ from tool_instance_v2.models import ToolInstance from tool_instance_v2.tool_processor import ToolProcessor from usage_v2.helper import UsageHelper +from utils.db_retry import db_retry from utils.local_context import StateStore from utils.user_context import UserContext @@ -391,6 +392,7 @@ def initiate_tool_execution( self.publish_log("Trying to fetch results from cache") @staticmethod + @db_retry() def update_execution_err(execution_id: str, err_msg: str = "") -> WorkflowExecution: try: execution = WorkflowExecution.objects.get(pk=execution_id) diff --git a/backend/workflow_manager/workflow_v2/models/execution.py b/backend/workflow_manager/workflow_v2/models/execution.py index 31c8988fe8..780190c164 100644 --- a/backend/workflow_manager/workflow_v2/models/execution.py +++ b/backend/workflow_manager/workflow_v2/models/execution.py @@ -12,6 +12,7 @@ from usage_v2.constants import UsageKeys from usage_v2.models import Usage from utils.common_utils import CommonUtils +from utils.db_retry import db_retry from utils.models.base_model import BaseModel from workflow_manager.execution.dto import ExecutionCache @@ -41,6 +42,7 @@ def for_user(self, user) -> QuerySet: # Return executions where the workflow's created_by matches the user return self.filter(workflow__created_by=user) + @db_retry() # Use environment defaults for retry settings def clean_invalid_workflows(self): """Remove execution records with invalid workflow references. 
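The retry behaviour of the @db_retry() decorators applied in these hunks is driven by the keyword tables in utils/db_constants.py; a minimal sketch of that classification, using messages drawn from the patterns defined above (matching is a case-insensitive substring search on the error string):

    from utils.db_constants import DatabaseErrorPatterns

    for msg in (
        "connection already closed",   # stale connection -> retried with pool refresh
        "connection refused",          # database unavailable -> retried with extended settings
        "connection timeout",          # transient -> retried with plain backoff, no refresh
        "syntax error at or near",     # unmatched -> NON_CONNECTION_ERROR, not retried
    ):
        error_type, needs_refresh = DatabaseErrorPatterns.classify_error(Exception(msg))
        print(f"{msg!r}: {error_type.value}, refresh={needs_refresh}")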
@@ -237,6 +239,7 @@ def __str__(self) -> str: f"error message: {self.error_message})" ) + @db_retry() # Use environment defaults for retry settings def update_execution( self, status: ExecutionStatus | None = None, @@ -271,6 +274,7 @@ def update_execution( self.save() + @db_retry() # Use environment defaults for retry settings def update_execution_err(self, err_msg: str = "") -> None: """Update execution status to ERROR with an error message. @@ -279,6 +283,7 @@ def update_execution_err(self, err_msg: str = "") -> None: """ self.update_execution(status=ExecutionStatus.ERROR, error=err_msg) + @db_retry() # Use environment defaults for retry settings def _handle_execution_cache(self): if not ExecutionCacheUtils.is_execution_exists( workflow_id=self.workflow.id, execution_id=self.id
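The call sites patched above all use the decorator form; utils/db_retry.py also documents a context-manager and a direct-call form. A minimal usage sketch (not part of the patch), assuming the environment defaults of 3 retries, 1.0s base delay and a 30.0s cap, with Organization used purely as an example model:

    from account_v2.models import Organization
    from utils.db_retry import db_retry, db_retry_context, retry_database_operation

    @db_retry()  # decorator form: retries the wrapped ORM call on connection drops
    def count_organizations() -> int:
        return Organization.objects.count()

    # Context-manager form (mirrors the docstring example) for a block of ORM calls.
    with db_retry_context(max_retries=5):
        org = Organization.objects.get(organization_id="default_org")
        org.save()

    # Direct form for a single callable.
    total = retry_database_operation(lambda: Organization.objects.count())

The decorator and direct forms both funnel into _execute_with_retry, while db_retry_context carries its own equivalent loop, so all three share the same error classification and exponential backoff settings.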