From 86cff591001162fc02de4a4be4af240aa86a32a1 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 9 Sep 2025 17:36:48 +0530 Subject: [PATCH 1/6] UN-2798 [FIX] Fix WorkflowFileExecution stuck in EXECUTING status - Add centralized database retry mechanism with connection pool refresh - Implement retry logic for Django ORM and Celery SQLAlchemy operations - Apply @db_retry decorators to critical workflow chain components - Handle stale connections, database unavailability, and transient errors - Configure exponential backoff with customizable retry parameters --- backend/account_v2/organization.py | 3 + backend/backend/celery_config.py | 37 +- backend/backend/celery_db_retry.py | 369 ++++++++++++++ backend/backend/celery_service.py | 4 + backend/backend/settings/base.py | 12 + .../file_processing/file_processing.py | 4 + .../file_processing_callback.py | 4 + backend/sample.env | 12 + backend/usage_v2/helper.py | 3 + backend/utils/db_constants.py | 329 +++++++++++++ backend/utils/db_retry.py | 449 ++++++++++++++++++ backend/utils/models/organization_mixin.py | 3 + backend/utils/user_context.py | 2 + .../workflow_manager/file_execution/models.py | 5 + .../workflow_manager/workflow_v2/execution.py | 2 + .../workflow_v2/models/execution.py | 5 + 16 files changed, 1242 insertions(+), 1 deletion(-) create mode 100644 backend/backend/celery_db_retry.py create mode 100644 backend/utils/db_constants.py create mode 100644 backend/utils/db_retry.py diff --git a/backend/account_v2/organization.py b/backend/account_v2/organization.py index a8586924d5..eea5cd596c 100644 --- a/backend/account_v2/organization.py +++ b/backend/account_v2/organization.py @@ -1,6 +1,7 @@ import logging from django.db import IntegrityError +from utils.db_retry import db_retry from account_v2.models import Organization @@ -12,6 +13,7 @@ def __init__(self): # type: ignore pass @staticmethod + @db_retry() # Add retry for connection drops during organization lookup def get_organization_by_org_id(org_id: str) -> Organization | None: try: return Organization.objects.get(organization_id=org_id) # type: ignore @@ -19,6 +21,7 @@ def get_organization_by_org_id(org_id: str) -> Organization | None: return None @staticmethod + @db_retry() # Add retry for connection drops during organization creation def create_organization( name: str, display_name: str, organization_id: str ) -> Organization: diff --git a/backend/backend/celery_config.py b/backend/backend/celery_config.py index f8833556e7..3629af6553 100644 --- a/backend/backend/celery_config.py +++ b/backend/backend/celery_config.py @@ -1,10 +1,16 @@ +import os from urllib.parse import quote_plus from django.conf import settings +from backend.celery_db_retry import get_celery_db_engine_options, should_use_builtin_retry + class CeleryConfig: - """Specifies celery configuration + """Specifies celery configuration with hybrid retry support. + + Supports both custom retry (via patching) and Celery's built-in retry + based on CELERY_USE_BUILTIN_RETRY environment variable. 
Refer https://docs.celeryq.dev/en/stable/userguide/configuration.html """ @@ -31,3 +37,32 @@ class CeleryConfig: beat_scheduler = "django_celery_beat.schedulers:DatabaseScheduler" task_acks_late = True + + # Database backend engine options for PgBouncer compatibility + result_backend_transport_options = get_celery_db_engine_options() + + # Hybrid retry configuration - built-in vs custom + if should_use_builtin_retry(): + # Use Celery's built-in database backend retry + result_backend_always_retry = ( + os.environ.get("CELERY_RESULT_BACKEND_ALWAYS_RETRY", "true").lower() == "true" + ) + result_backend_max_retries = int( + os.environ.get("CELERY_RESULT_BACKEND_MAX_RETRIES", "3") + ) + result_backend_base_sleep_between_retries_ms = int( + os.environ.get("CELERY_RESULT_BACKEND_BASE_SLEEP_BETWEEN_RETRIES_MS", "1000") + ) + result_backend_max_sleep_between_retries_ms = int( + os.environ.get("CELERY_RESULT_BACKEND_MAX_SLEEP_BETWEEN_RETRIES_MS", "30000") + ) + + print( + f"[Celery Config] Using built-in retry: max_retries={result_backend_max_retries}, " + ) + print( + f"base_sleep={result_backend_base_sleep_between_retries_ms}ms, max_sleep={result_backend_max_sleep_between_retries_ms}ms" + ) + else: + # Custom retry is handled by patch_celery_database_backend() + print("[Celery Config] Using custom retry system (patching enabled)") diff --git a/backend/backend/celery_db_retry.py b/backend/backend/celery_db_retry.py new file mode 100644 index 0000000000..ded5ac1584 --- /dev/null +++ b/backend/backend/celery_db_retry.py @@ -0,0 +1,369 @@ +import logging +import time +from collections.abc import Callable +from functools import wraps +from typing import Any + +from utils.db_constants import ( + DatabaseErrorPatterns, + DatabaseErrorType, + LogMessages, + RetryConfiguration, +) + +logger = logging.getLogger(__name__) + + +def should_use_builtin_retry() -> bool: + """Check if we should use Celery's built-in retry instead of custom retry. + + Returns: + bool: True if built-in retry should be used, False for custom retry + """ + return RetryConfiguration.get_setting_value( + RetryConfiguration.ENV_CELERY_USE_BUILTIN, False + ) + + +def celery_db_retry_with_backoff( + max_retries: int | None = None, + base_delay: float | None = None, + max_delay: float | None = None, +): + """Decorator to retry Celery database backend operations with exponential backoff. + + This is specifically designed for Celery's database result backend operations + that may experience connection drops when using PgBouncer or database restarts. 
+ + Args: + max_retries: Maximum number of retry attempts (defaults to settings or 3) + base_delay: Initial delay in seconds (defaults to settings or 1.0) + max_delay: Maximum delay in seconds (defaults to settings or 30.0) + """ + # Get defaults from centralized configuration + celery_settings = RetryConfiguration.get_celery_retry_settings() + + if max_retries is None: + max_retries = celery_settings["max_retries"] + if base_delay is None: + base_delay = celery_settings["base_delay"] + if max_delay is None: + max_delay = celery_settings["max_delay"] + + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + logger.debug( + LogMessages.format_message( + LogMessages.CELERY_OPERATION_START, operation=func.__name__ + ) + ) + retry_count = 0 + while retry_count <= max_retries: + try: + result = func(*args, **kwargs) + if retry_count == 0: + logger.debug(LogMessages.OPERATION_SUCCESS) + else: + logger.info( + LogMessages.format_message( + LogMessages.OPERATION_SUCCESS_AFTER_RETRY, + retry_count=retry_count, + ) + ) + return result + except Exception as e: + # Import here to avoid circular import + try: + from sqlalchemy.exc import ( + OperationalError as SQLAlchemyOperationalError, + ) + + is_sqlalchemy_error = isinstance(e, SQLAlchemyOperationalError) + except ImportError: + is_sqlalchemy_error = False + + if not is_sqlalchemy_error: + logger.debug( + LogMessages.format_message( + LogMessages.NON_RETRYABLE_ERROR, error=e + ) + ) + raise + + # Use centralized error classification + error_type, needs_refresh = DatabaseErrorPatterns.classify_error(e) + + if DatabaseErrorPatterns.is_retryable_error(error_type): + # For database unavailable errors, use extended settings if configured + current_max_retries = max_retries + current_base_delay = base_delay + current_max_delay = max_delay + + if ( + error_type == DatabaseErrorType.DATABASE_UNAVAILABLE + and retry_count == 0 + ): + extended_settings = RetryConfiguration.get_retry_settings( + use_extended=True + ) + current_max_retries = extended_settings["max_retries"] + current_base_delay = extended_settings["base_delay"] + current_max_delay = extended_settings["max_delay"] + + if retry_count < current_max_retries: + delay = min( + current_base_delay * (2**retry_count), current_max_delay + ) + retry_count += 1 + + # Handle SQLAlchemy connection pool disposal for severe connection issues + if needs_refresh: + logger.warning( + LogMessages.format_message( + LogMessages.POOL_CORRUPTION_DETECTED, + attempt=retry_count, + total=current_max_retries + 1, + error=e, + ) + ) + try: + _dispose_sqlalchemy_engine(func) + logger.info( + "SQLAlchemy connection pool disposed successfully" + ) + except Exception as refresh_error: + logger.warning( + LogMessages.format_message( + LogMessages.POOL_REFRESH_FAILED, + error=refresh_error, + ) + ) + else: + # Choose appropriate retry message based on error type + if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE: + message = LogMessages.DATABASE_UNAVAILABLE_RETRY + else: + message = LogMessages.CONNECTION_ERROR_RETRY + + logger.warning( + LogMessages.format_message( + message, + attempt=retry_count, + total=current_max_retries + 1, + error=e, + delay=delay, + ) + ) + + time.sleep(delay) + continue + else: + logger.error( + LogMessages.format_message( + LogMessages.MAX_RETRIES_EXCEEDED, + total=current_max_retries + 1, + error=e, + ) + ) + raise + else: + # Not a connection error, re-raise immediately + logger.debug( + LogMessages.format_message( + 
LogMessages.NON_RETRYABLE_ERROR, error=e + ) + ) + raise + + # This should never be reached, but included for completeness + return func(*args, **kwargs) + + return wrapper + + return decorator + + +def _dispose_sqlalchemy_engine(func): + """Dispose SQLAlchemy engine to force connection pool recreation. + + This is called when we detect severe connection issues that require + the entire connection pool to be recreated. + """ + try: + # Try to get the engine from the Celery backend + from celery import current_app + + backend = current_app.backend + if hasattr(backend, "engine") and backend.engine: + backend.engine.dispose() + logger.info("Disposed SQLAlchemy engine for connection pool refresh") + except Exception as e: + logger.warning(f"Could not dispose SQLAlchemy engine: {e}") + + +# Track if patching has been applied to prevent double patching +_celery_backend_patched = False + + +def patch_celery_database_backend(): + """Dynamically patch Celery's database backend classes to add retry logic. + + This function should be called during application initialization to add + connection retry capabilities to Celery's database backend operations. + + Supports hybrid mode: can use either custom retry logic or Celery's built-in retry + based on CELERY_USE_BUILTIN_RETRY environment variable/setting. + """ + global _celery_backend_patched + + # Prevent double patching + if _celery_backend_patched: + logger.debug("Celery database backend already patched, skipping") + return + + # Check if we should use built-in retry instead of custom + if should_use_builtin_retry(): + logger.info( + "Using Celery's built-in database backend retry (CELERY_USE_BUILTIN_RETRY=true)" + ) + _configure_builtin_retry() + _celery_backend_patched = True + return + + logger.info("Using custom database backend retry (CELERY_USE_BUILTIN_RETRY=false)") + + try: + from celery.backends.database import DatabaseBackend + + # Check if already patched by looking for our marker attribute + if hasattr(DatabaseBackend._store_result, "_retry_patched"): + logger.debug( + "Celery database backend already patched (detected marker), skipping" + ) + _celery_backend_patched = True + return + + # Store original methods - bypass Celery's built-in retry by accessing the real methods + # This prevents double-retry conflicts when using custom retry + original_store_result = getattr( + DatabaseBackend._store_result, "__wrapped__", DatabaseBackend._store_result + ) + original_get_task_meta_for = getattr( + DatabaseBackend._get_task_meta_for, + "__wrapped__", + DatabaseBackend._get_task_meta_for, + ) + original_get_result = getattr( + DatabaseBackend.get_result, "__wrapped__", DatabaseBackend.get_result + ) + + logger.info( + "Bypassing Celery's built-in retry system to prevent conflicts (using custom retry)" + ) + + # Apply retry decorators + logger.info("Patching Celery method: _store_result with retry logic") + DatabaseBackend._store_result = celery_db_retry_with_backoff()( + original_store_result + ) + + logger.info("Patching Celery method: _get_task_meta_for with retry logic") + DatabaseBackend._get_task_meta_for = celery_db_retry_with_backoff()( + original_get_task_meta_for + ) + + logger.info("Patching Celery method: get_result with retry logic") + DatabaseBackend.get_result = celery_db_retry_with_backoff()(original_get_result) + + # delete_result may not exist in all Celery versions + if hasattr(DatabaseBackend, "delete_result"): + logger.info("Patching Celery method: delete_result with retry logic") + original_delete_result = getattr( + 
DatabaseBackend.delete_result, + "__wrapped__", + DatabaseBackend.delete_result, + ) + DatabaseBackend.delete_result = celery_db_retry_with_backoff()( + original_delete_result + ) + else: + logger.info("Celery method delete_result not found, skipping patch") + + # Mark as patched to prevent double patching + DatabaseBackend._store_result._retry_patched = True + _celery_backend_patched = True + + logger.info("Successfully patched Celery database backend with retry logic") + + except ImportError as e: + logger.warning(f"Could not import Celery database backend for patching: {e}") + except Exception as e: + logger.error(f"Error patching Celery database backend: {e}") + + +def _configure_builtin_retry(): + """Configure Celery's built-in database backend retry settings. + + This is called when CELERY_USE_BUILTIN_RETRY=true to setup Celery's + native retry functionality instead of our custom implementation. + """ + try: + import os + + # Get retry configuration from centralized settings + settings = RetryConfiguration.get_retry_settings() + + max_retries = settings["max_retries"] + base_delay_ms = int(settings["base_delay"] * 1000) + max_delay_ms = int(settings["max_delay"] * 1000) + + # Apply built-in retry settings to Celery configuration + logger.info( + f"Configured Celery built-in retry: max_retries={max_retries}, base_delay={base_delay_ms}ms, max_delay={max_delay_ms}ms" + ) + + # Store settings for use in celery_config.py + os.environ["CELERY_RESULT_BACKEND_ALWAYS_RETRY"] = "true" + os.environ["CELERY_RESULT_BACKEND_MAX_RETRIES"] = str(max_retries) + os.environ["CELERY_RESULT_BACKEND_BASE_SLEEP_BETWEEN_RETRIES_MS"] = str( + base_delay_ms + ) + os.environ["CELERY_RESULT_BACKEND_MAX_SLEEP_BETWEEN_RETRIES_MS"] = str( + max_delay_ms + ) + + logger.info("Successfully configured Celery's built-in database backend retry") + + except Exception as e: + logger.error(f"Error configuring Celery built-in retry: {e}") + logger.warning("Falling back to no retry mechanism") + + +def get_celery_db_engine_options(): + """Get SQLAlchemy engine options optimized for use with PgBouncer. + + Includes built-in retry configuration if CELERY_USE_BUILTIN_RETRY is enabled. + + These options are designed to work well with PgBouncer connection pooling + without interfering with PgBouncer's pool management. + + Returns: + dict: SQLAlchemy engine options + """ + return { + # Connection health checking + "pool_pre_ping": True, # Test connections before use + # Minimal pooling (let PgBouncer handle the real pooling) + "pool_size": 5, # Small pool since PgBouncer handles real pooling + "max_overflow": 0, # No overflow, rely on PgBouncer + "pool_recycle": 3600, # Recycle connections every hour + # Connection timeouts using centralized configuration + "connect_args": { + "connect_timeout": RetryConfiguration.get_setting_value( + "CELERY_DB_CONNECT_TIMEOUT", 30 + ), + }, + # Echo SQL queries if debug logging is enabled + "echo": RetryConfiguration.get_setting_value("CELERY_DB_ECHO_SQL", False), + } diff --git a/backend/backend/celery_service.py b/backend/backend/celery_service.py index 9b0e19b6f0..ec5ddc63d9 100644 --- a/backend/backend/celery_service.py +++ b/backend/backend/celery_service.py @@ -7,6 +7,7 @@ from celery import Celery +from backend.celery_db_retry import patch_celery_database_backend from backend.celery_task import TaskRegistry from backend.settings.base import LOGGING @@ -24,6 +25,9 @@ # Create a Celery instance. Default time zone is UTC. 
app = Celery("backend") +# Patch Celery database backend to add connection retry logic +patch_celery_database_backend() + # Register custom tasks TaskRegistry() diff --git a/backend/backend/settings/base.py b/backend/backend/settings/base.py index cb18de0ca5..7daf4c9639 100644 --- a/backend/backend/settings/base.py +++ b/backend/backend/settings/base.py @@ -115,6 +115,18 @@ def get_required_setting(setting_key: str, default: str | None = None) -> str | DB_PORT = os.environ.get("DB_PORT", 5432) DB_SCHEMA = os.environ.get("DB_SCHEMA", "unstract") +# Database connection retry settings +DB_CONNECTION_RETRY_COUNT = int(os.environ.get("DB_CONNECTION_RETRY_COUNT", "3")) +DB_CONNECTION_RETRY_DELAY = int(os.environ.get("DB_CONNECTION_RETRY_DELAY", "1")) +DB_CONNECTION_RETRY_MAX_DELAY = int(os.environ.get("DB_CONNECTION_RETRY_MAX_DELAY", "30")) + +# Celery database backend retry settings +CELERY_DB_RETRY_COUNT = int(os.environ.get("CELERY_DB_RETRY_COUNT", "3")) +CELERY_DB_RETRY_DELAY = int(os.environ.get("CELERY_DB_RETRY_DELAY", "1")) +CELERY_DB_RETRY_MAX_DELAY = int(os.environ.get("CELERY_DB_RETRY_MAX_DELAY", "30")) +CELERY_DB_CONNECT_TIMEOUT = int(os.environ.get("CELERY_DB_CONNECT_TIMEOUT", "30")) +CELERY_DB_ECHO_SQL = os.environ.get("CELERY_DB_ECHO_SQL", "False").lower() == "true" + # Celery Backend Database Name (falls back to main DB when unset or empty) CELERY_BACKEND_DB_NAME = os.environ.get("CELERY_BACKEND_DB_NAME") or DB_NAME DEFAULT_ORGANIZATION = "default_org" diff --git a/backend/backend/workers/file_processing/file_processing.py b/backend/backend/workers/file_processing/file_processing.py index b6c1f975e6..997e5c6792 100644 --- a/backend/backend/workers/file_processing/file_processing.py +++ b/backend/backend/workers/file_processing/file_processing.py @@ -7,6 +7,7 @@ from celery import Celery +from backend.celery_db_retry import patch_celery_database_backend from backend.settings.base import LOGGING from backend.workers.constants import CeleryWorkerNames from backend.workers.file_processing.celery_config import CeleryConfig @@ -25,6 +26,9 @@ # Create a Celery instance. Default time zone is UTC. app = Celery(CeleryWorkerNames.FILE_PROCESSING) +# Patch Celery database backend to add connection retry logic +patch_celery_database_backend() + # Load task modules from all registered Django app configs. app.config_from_object(CeleryConfig) diff --git a/backend/backend/workers/file_processing_callback/file_processing_callback.py b/backend/backend/workers/file_processing_callback/file_processing_callback.py index ecc20c6436..764e17fefd 100644 --- a/backend/backend/workers/file_processing_callback/file_processing_callback.py +++ b/backend/backend/workers/file_processing_callback/file_processing_callback.py @@ -7,6 +7,7 @@ from celery import Celery +from backend.celery_db_retry import patch_celery_database_backend from backend.settings.base import LOGGING from backend.workers.constants import CeleryWorkerNames from backend.workers.file_processing_callback.celery_config import CeleryConfig @@ -25,6 +26,9 @@ # Create a Celery instance. Default time zone is UTC. app = Celery(CeleryWorkerNames.FILE_PROCESSING_CALLBACK) +# Patch Celery database backend to add connection retry logic +patch_celery_database_backend() + # Load task modules from all registered Django app configs. 
app.config_from_object(CeleryConfig) diff --git a/backend/sample.env b/backend/sample.env index dcd79b5a16..8b87125bed 100644 --- a/backend/sample.env +++ b/backend/sample.env @@ -25,6 +25,18 @@ DB_SCHEMA="unstract" # Celery Backend Database (optional - defaults to DB_NAME if unset) # Example: # CELERY_BACKEND_DB_NAME=unstract_celery_db + +# Database connection retry settings (for handling connection drops with PgBouncer) +# DB_CONNECTION_RETRY_COUNT=3 +# DB_CONNECTION_RETRY_DELAY=1 +# DB_CONNECTION_RETRY_MAX_DELAY=30 + +# Celery database backend retry settings (for handling connection drops with PgBouncer) +# CELERY_DB_RETRY_COUNT=3 +# CELERY_DB_RETRY_DELAY=1 +# CELERY_DB_RETRY_MAX_DELAY=30 +# CELERY_DB_CONNECT_TIMEOUT=30 +# CELERY_DB_ECHO_SQL=False # Redis REDIS_HOST="unstract-redis" REDIS_PORT=6379 diff --git a/backend/usage_v2/helper.py b/backend/usage_v2/helper.py index 9452547acb..29e3d108ff 100644 --- a/backend/usage_v2/helper.py +++ b/backend/usage_v2/helper.py @@ -4,6 +4,7 @@ from django.db.models import QuerySet, Sum from rest_framework.exceptions import APIException +from utils.db_retry import db_retry from .constants import UsageKeys from .models import Usage @@ -13,6 +14,7 @@ class UsageHelper: @staticmethod + @db_retry() # Add retry for connection drops during usage aggregation def get_aggregated_token_count(run_id: str) -> dict: """Retrieve aggregated token counts for the given run_id. @@ -64,6 +66,7 @@ def get_aggregated_token_count(run_id: str) -> dict: raise APIException("Error while aggregating token counts") @staticmethod + @db_retry() # Add retry for connection drops during metrics aggregation def aggregate_usage_metrics(queryset: QuerySet) -> dict[str, Any]: """Aggregate usage metrics from a queryset of Usage objects. diff --git a/backend/utils/db_constants.py b/backend/utils/db_constants.py new file mode 100644 index 0000000000..0796a78f38 --- /dev/null +++ b/backend/utils/db_constants.py @@ -0,0 +1,329 @@ +"""Database Error Constants and Configuration + +Centralized error patterns, types, and configuration for database retry mechanisms. +Used by both Django ORM retry (db_retry.py) and Celery SQLAlchemy retry (celery_db_retry.py). +""" + +import os +from dataclasses import dataclass +from enum import Enum + + +class DatabaseErrorType(Enum): + """Classification of database connection errors.""" + + STALE_CONNECTION = "stale_connection" + DATABASE_UNAVAILABLE = "database_unavailable" + TRANSIENT_ERROR = "transient_error" + NON_CONNECTION_ERROR = "non_connection_error" + + +class ConnectionPoolType(Enum): + """Types of connection pools that need different refresh strategies.""" + + DJANGO_ORM = "django_orm" + SQLALCHEMY = "sqlalchemy" + + +@dataclass(frozen=True) +class ErrorPattern: + """Defines an error pattern with its classification and handling strategy.""" + + keywords: tuple[str, ...] 
+ error_type: DatabaseErrorType + requires_pool_refresh: bool + description: str + + +class DatabaseErrorPatterns: + """Centralized database error patterns for consistent classification.""" + + # Stale connection patterns - connection exists but is dead + STALE_CONNECTION_PATTERNS = ( + ErrorPattern( + keywords=("connection already closed",), + error_type=DatabaseErrorType.STALE_CONNECTION, + requires_pool_refresh=True, + description="Django InterfaceError - connection pool corruption", + ), + ErrorPattern( + keywords=("server closed the connection unexpectedly",), + error_type=DatabaseErrorType.STALE_CONNECTION, + requires_pool_refresh=True, + description="PostgreSQL dropped connection, pool may be stale", + ), + ErrorPattern( + keywords=("connection was lost",), + error_type=DatabaseErrorType.STALE_CONNECTION, + requires_pool_refresh=True, + description="Connection lost during operation", + ), + ) + + # Database unavailable patterns - database is completely unreachable + DATABASE_UNAVAILABLE_PATTERNS = ( + ErrorPattern( + keywords=("connection refused",), + error_type=DatabaseErrorType.DATABASE_UNAVAILABLE, + requires_pool_refresh=True, + description="Database server is down or unreachable", + ), + ErrorPattern( + keywords=("could not connect",), + error_type=DatabaseErrorType.DATABASE_UNAVAILABLE, + requires_pool_refresh=True, + description="Unable to establish database connection", + ), + ErrorPattern( + keywords=("no route to host",), + error_type=DatabaseErrorType.DATABASE_UNAVAILABLE, + requires_pool_refresh=True, + description="Network routing issue to database", + ), + ) + + # Transient error patterns - temporary issues, retry without pool refresh + TRANSIENT_ERROR_PATTERNS = ( + ErrorPattern( + keywords=("connection timeout",), + error_type=DatabaseErrorType.TRANSIENT_ERROR, + requires_pool_refresh=False, + description="Connection attempt timed out", + ), + ErrorPattern( + keywords=("connection pool exhausted",), + error_type=DatabaseErrorType.TRANSIENT_ERROR, + requires_pool_refresh=False, + description="Connection pool temporarily full", + ), + ErrorPattern( + keywords=("connection closed",), + error_type=DatabaseErrorType.TRANSIENT_ERROR, + requires_pool_refresh=False, + description="Connection closed during operation", + ), + # MySQL compatibility patterns + ErrorPattern( + keywords=("lost connection to mysql server",), + error_type=DatabaseErrorType.TRANSIENT_ERROR, + requires_pool_refresh=False, + description="MySQL connection lost", + ), + ErrorPattern( + keywords=("mysql server has gone away",), + error_type=DatabaseErrorType.TRANSIENT_ERROR, + requires_pool_refresh=False, + description="MySQL server disconnected", + ), + ) + + # All patterns combined for easy iteration + ALL_PATTERNS = ( + STALE_CONNECTION_PATTERNS + + DATABASE_UNAVAILABLE_PATTERNS + + TRANSIENT_ERROR_PATTERNS + ) + + @classmethod + def classify_error( + cls, error: Exception, error_message: str = None + ) -> tuple[DatabaseErrorType, bool]: + """Classify a database error and determine if pool refresh is needed. 
+ + Args: + error: The exception to classify + error_message: Optional pre-lowercased error message for efficiency + + Returns: + Tuple[DatabaseErrorType, bool]: (error_type, needs_pool_refresh) + """ + if error_message is None: + error_message = str(error).lower() + else: + error_message = error_message.lower() + + # Check all patterns for matches + for pattern in cls.ALL_PATTERNS: + if any(keyword in error_message for keyword in pattern.keywords): + return pattern.error_type, pattern.requires_pool_refresh + + # No pattern matched + return DatabaseErrorType.NON_CONNECTION_ERROR, False + + @classmethod + def is_retryable_error(cls, error_type: DatabaseErrorType) -> bool: + """Check if an error type should be retried.""" + return error_type in { + DatabaseErrorType.STALE_CONNECTION, + DatabaseErrorType.DATABASE_UNAVAILABLE, + DatabaseErrorType.TRANSIENT_ERROR, + } + + @classmethod + def get_all_error_keywords(cls) -> list[str]: + """Get all error keywords as a flat list (for backward compatibility).""" + keywords = [] + for pattern in cls.ALL_PATTERNS: + keywords.extend(pattern.keywords) + return keywords + + +class RetryConfiguration: + """Centralized configuration for database retry settings.""" + + # Environment variable names + ENV_MAX_RETRIES = "DB_RETRY_MAX_RETRIES" + ENV_BASE_DELAY = "DB_RETRY_BASE_DELAY" + ENV_MAX_DELAY = "DB_RETRY_MAX_DELAY" + ENV_FORCE_REFRESH = "DB_RETRY_FORCE_REFRESH" + + # Celery-specific environment variables + ENV_CELERY_USE_BUILTIN = "CELERY_USE_BUILTIN_RETRY" + ENV_CELERY_MAX_RETRIES = "CELERY_DB_RETRY_COUNT" + ENV_CELERY_BASE_DELAY = "CELERY_DB_RETRY_DELAY" + ENV_CELERY_MAX_DELAY = "CELERY_DB_RETRY_MAX_DELAY" + + # Default values + DEFAULT_MAX_RETRIES = 3 + DEFAULT_BASE_DELAY = 1.0 + DEFAULT_MAX_DELAY = 30.0 + DEFAULT_FORCE_REFRESH = True + + # Extended retry for database unavailable scenarios + DEFAULT_EXTENDED_MAX_RETRIES = 8 + DEFAULT_EXTENDED_BASE_DELAY = 2.0 + DEFAULT_EXTENDED_MAX_DELAY = 60.0 + + @classmethod + def get_setting_value(cls, setting_name: str, default_value, use_django=True): + """Get setting value from Django settings or environment with proper type conversion. 
+ + Args: + setting_name: Name of the setting + default_value: Default value (determines return type) + use_django: Whether to try Django settings first + + Returns: + Setting value converted to same type as default_value + """ + # Try Django settings first (if available and requested) + if use_django: + try: + from django.conf import settings + from django.core.exceptions import ImproperlyConfigured + + if hasattr(settings, setting_name): + return getattr(settings, setting_name) + except (ImportError, ImproperlyConfigured): + # Django not available or not configured + pass + + # Fall back to environment variables + env_value = os.environ.get(setting_name) + if env_value is not None: + return cls._convert_env_value(env_value, default_value) + + return default_value + + @classmethod + def _convert_env_value(cls, env_value: str, default_value): + """Convert environment variable string to appropriate type.""" + try: + # Check bool first since bool is subclass of int in Python + if isinstance(default_value, bool): + return env_value.lower() in ("true", "1", "yes", "on") + elif isinstance(default_value, int): + return int(env_value) + elif isinstance(default_value, float): + return float(env_value) + else: + return env_value + except (ValueError, TypeError): + return default_value + + @classmethod + def get_retry_settings(cls, use_extended=False, use_django=True) -> dict: + """Get complete retry configuration. + + Args: + use_extended: Use extended settings for database unavailable scenarios + use_django: Whether to check Django settings + + Returns: + Dict with retry configuration + """ + if use_extended: + max_retries_default = cls.DEFAULT_EXTENDED_MAX_RETRIES + base_delay_default = cls.DEFAULT_EXTENDED_BASE_DELAY + max_delay_default = cls.DEFAULT_EXTENDED_MAX_DELAY + else: + max_retries_default = cls.DEFAULT_MAX_RETRIES + base_delay_default = cls.DEFAULT_BASE_DELAY + max_delay_default = cls.DEFAULT_MAX_DELAY + + return { + "max_retries": cls.get_setting_value( + cls.ENV_MAX_RETRIES, max_retries_default, use_django + ), + "base_delay": cls.get_setting_value( + cls.ENV_BASE_DELAY, base_delay_default, use_django + ), + "max_delay": cls.get_setting_value( + cls.ENV_MAX_DELAY, max_delay_default, use_django + ), + "force_refresh": cls.get_setting_value( + cls.ENV_FORCE_REFRESH, cls.DEFAULT_FORCE_REFRESH, use_django + ), + } + + @classmethod + def get_celery_retry_settings(cls, use_django=True) -> dict: + """Get Celery-specific retry configuration.""" + return { + "max_retries": cls.get_setting_value( + cls.ENV_CELERY_MAX_RETRIES, cls.DEFAULT_MAX_RETRIES, use_django + ), + "base_delay": cls.get_setting_value( + cls.ENV_CELERY_BASE_DELAY, cls.DEFAULT_BASE_DELAY, use_django + ), + "max_delay": cls.get_setting_value( + cls.ENV_CELERY_MAX_DELAY, cls.DEFAULT_MAX_DELAY, use_django + ), + "use_builtin": cls.get_setting_value( + cls.ENV_CELERY_USE_BUILTIN, False, use_django + ), + } + + +class LogMessages: + """Centralized log message templates for consistent logging.""" + + # Success messages + OPERATION_SUCCESS = "Operation completed successfully" + OPERATION_SUCCESS_AFTER_RETRY = ( + "Operation succeeded after {retry_count} retry attempts" + ) + + # Connection pool refresh messages + POOL_CORRUPTION_DETECTED = "Database connection pool corruption detected (attempt {attempt}/{total}): {error}. Refreshing connection pool..." 
+ POOL_REFRESH_SUCCESS = "Connection pool refreshed successfully" + POOL_REFRESH_FAILED = "Failed to refresh connection pool: {error}" + + # Retry attempt messages + CONNECTION_ERROR_RETRY = "Database connection error (attempt {attempt}/{total}): {error}. Retrying in {delay} seconds..." + DATABASE_UNAVAILABLE_RETRY = "Database unavailable (attempt {attempt}/{total}): {error}. Using extended retry in {delay} seconds..." + + # Failure messages + MAX_RETRIES_EXCEEDED = "Database connection failed after {total} attempts: {error}" + NON_RETRYABLE_ERROR = "Non-retryable error, not retrying: {error}" + + # Execution messages + EXECUTING_WITH_RETRY = "Executing operation with retry logic..." + CELERY_OPERATION_START = ( + "Executing Celery DB operation: {operation} (retry mechanism active)" + ) + + @classmethod + def format_message(cls, template: str, **kwargs) -> str: + """Format a log message template with provided kwargs.""" + return template.format(**kwargs) diff --git a/backend/utils/db_retry.py b/backend/utils/db_retry.py new file mode 100644 index 0000000000..752284e2f1 --- /dev/null +++ b/backend/utils/db_retry.py @@ -0,0 +1,449 @@ +"""Database Retry Utility + +A simple, focused utility for retrying Django ORM operations that fail due to +connection issues. Designed for use in models, views, services, and anywhere +database operations might encounter connection drops. + +Features: +- Automatic connection pool refresh for stale connections +- Extended retry logic for database unavailable scenarios +- Centralized error classification and handling +- Environment variable configuration +- Multiple usage patterns: decorators, context managers, direct calls + +Usage Examples: + + # Decorator usage + @db_retry(max_retries=3) + def my_database_operation(): + Model.objects.create(...) + + # Context manager usage + with db_retry_context(max_retries=3): + model.save() + other_model.delete() + + # Direct function usage + result = retry_database_operation( + lambda: MyModel.objects.filter(...).update(...), + max_retries=3 + ) +""" + +import logging +import time +from collections.abc import Callable +from contextlib import contextmanager +from functools import wraps +from typing import Any + +from django.db import close_old_connections +from django.db.utils import InterfaceError, OperationalError + +from utils.db_constants import ( + DatabaseErrorPatterns, + DatabaseErrorType, + LogMessages, + RetryConfiguration, +) + +logger = logging.getLogger(__name__) + + +# Legacy function for backward compatibility - now delegates to RetryConfiguration +def get_retry_setting(setting_name: str, default_value): + """Get retry setting from Django settings or environment variables with fallback to default. + + Args: + setting_name: The setting name to look for + default_value: Default value if setting is not found + + Returns: + The setting value (int or float based on default_value type) + """ + return RetryConfiguration.get_setting_value(setting_name, default_value) + + +def get_default_retry_settings(use_extended=False): + """Get default retry settings from environment or use built-in defaults. 
+ + Environment variables: + DB_RETRY_MAX_RETRIES: Maximum number of retry attempts (default: 3 or 8 if extended) + DB_RETRY_BASE_DELAY: Initial delay between retries in seconds (default: 1.0 or 2.0 if extended) + DB_RETRY_MAX_DELAY: Maximum delay between retries in seconds (default: 30.0 or 60.0 if extended) + DB_RETRY_FORCE_REFRESH: Force connection pool refresh on stale connections (default: True) + + Args: + use_extended: Use extended retry settings for database unavailable scenarios + + Returns: + dict: Dictionary with retry settings + """ + return RetryConfiguration.get_retry_settings(use_extended=use_extended) + + +def _classify_database_error(error: Exception): + """Classify a database error using centralized error patterns. + + Args: + error: The exception to classify + + Returns: + Tuple containing (error_type, needs_refresh, use_extended_retry) + """ + # Only classify Django database errors + if not isinstance(error, (OperationalError, InterfaceError)): + return DatabaseErrorType.NON_CONNECTION_ERROR, False, False + + error_type, needs_refresh = DatabaseErrorPatterns.classify_error(error) + + # For InterfaceError with stale connection patterns, always refresh + if ( + isinstance(error, InterfaceError) + and error_type == DatabaseErrorType.STALE_CONNECTION + ): + needs_refresh = True + + # Use extended retry for database unavailable scenarios + use_extended_retry = error_type == DatabaseErrorType.DATABASE_UNAVAILABLE + + return error_type, needs_refresh, use_extended_retry + + +def _execute_with_retry( + operation: Callable, + max_retries: int = 3, + base_delay: float = 1.0, + max_delay: float = 30.0, + force_refresh: bool = True, +) -> Any: + """Execute an operation with retry logic for connection errors. + + Args: + operation: The operation to execute + max_retries: Maximum number of retry attempts + base_delay: Initial delay between retries in seconds + max_delay: Maximum delay between retries in seconds + force_refresh: Whether to force connection pool refresh on stale connections + + Returns: + The result of the operation + + Raises: + The original exception if max retries exceeded or non-connection error + """ + retry_count = 0 + + while retry_count <= max_retries: + try: + logger.debug(LogMessages.EXECUTING_WITH_RETRY) + return operation() + except Exception as e: + error_type, needs_refresh, use_extended = _classify_database_error(e) + + if not DatabaseErrorPatterns.is_retryable_error(error_type): + logger.debug( + LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=e) + ) + raise + + # For database unavailable errors, use extended settings if configured + if use_extended and retry_count == 0: + extended_settings = get_default_retry_settings(use_extended=True) + max_retries = extended_settings["max_retries"] + base_delay = extended_settings["base_delay"] + max_delay = extended_settings["max_delay"] + + if retry_count < max_retries: + delay = min(base_delay * (2**retry_count), max_delay) + retry_count += 1 + + # Handle connection pool refresh for stale connections + if needs_refresh and force_refresh: + logger.warning( + LogMessages.format_message( + LogMessages.POOL_CORRUPTION_DETECTED, + attempt=retry_count, + total=max_retries + 1, + error=e, + ) + ) + try: + close_old_connections() + logger.info(LogMessages.POOL_REFRESH_SUCCESS) + except Exception as refresh_error: + logger.warning( + LogMessages.format_message( + LogMessages.POOL_REFRESH_FAILED, error=refresh_error + ) + ) + else: + # Choose appropriate retry message based on error type + if error_type 
== DatabaseErrorType.DATABASE_UNAVAILABLE: + message = LogMessages.DATABASE_UNAVAILABLE_RETRY + else: + message = LogMessages.CONNECTION_ERROR_RETRY + + logger.warning( + LogMessages.format_message( + message, + attempt=retry_count, + total=max_retries + 1, + error=e, + delay=delay, + ) + ) + + time.sleep(delay) + continue + else: + logger.error( + LogMessages.format_message( + LogMessages.MAX_RETRIES_EXCEEDED, total=max_retries + 1, error=e + ) + ) + raise + + # This should never be reached, but included for completeness + return operation() + + +def db_retry( + max_retries: int | None = None, + base_delay: float | None = None, + max_delay: float | None = None, + force_refresh: bool | None = None, +) -> Callable: + """Decorator to retry database operations on connection errors. + + Uses environment variables DB_RETRY_MAX_RETRIES, DB_RETRY_BASE_DELAY, DB_RETRY_MAX_DELAY, + DB_RETRY_FORCE_REFRESH for defaults if parameters are not provided. + + Args: + max_retries: Maximum number of retry attempts (default from env: DB_RETRY_MAX_RETRIES or 3) + base_delay: Initial delay between retries in seconds (default from env: DB_RETRY_BASE_DELAY or 1.0) + max_delay: Maximum delay between retries in seconds (default from env: DB_RETRY_MAX_DELAY or 30.0) + force_refresh: Force connection pool refresh on stale connections (default from env: DB_RETRY_FORCE_REFRESH or True) + + Returns: + Decorated function with retry logic + + Example: + @db_retry(max_retries=5, base_delay=0.5, force_refresh=True) + def create_user(name, email): + return User.objects.create(name=name, email=email) + + # Using environment defaults + @db_retry() + def save_model(self): + self.save() + """ + # Get defaults from environment if not provided + defaults = get_default_retry_settings() + final_max_retries = ( + max_retries if max_retries is not None else defaults["max_retries"] + ) + final_base_delay = base_delay if base_delay is not None else defaults["base_delay"] + final_max_delay = max_delay if max_delay is not None else defaults["max_delay"] + final_force_refresh = ( + force_refresh if force_refresh is not None else defaults["force_refresh"] + ) + + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs) -> Any: + def operation(): + return func(*args, **kwargs) + + return _execute_with_retry( + operation=operation, + max_retries=final_max_retries, + base_delay=final_base_delay, + max_delay=final_max_delay, + force_refresh=final_force_refresh, + ) + + return wrapper + + return decorator + + +@contextmanager +def db_retry_context( + max_retries: int | None = None, + base_delay: float | None = None, + max_delay: float | None = None, + force_refresh: bool | None = None, +): + """Context manager for retrying database operations on connection errors. + + Uses environment variables DB_RETRY_MAX_RETRIES, DB_RETRY_BASE_DELAY, DB_RETRY_MAX_DELAY, + DB_RETRY_FORCE_REFRESH for defaults if parameters are not provided. + + Args: + max_retries: Maximum number of retry attempts (default from env: DB_RETRY_MAX_RETRIES or 3) + base_delay: Initial delay between retries in seconds (default from env: DB_RETRY_BASE_DELAY or 1.0) + max_delay: Maximum delay between retries in seconds (default from env: DB_RETRY_MAX_DELAY or 30.0) + force_refresh: Force connection pool refresh on stale connections (default from env: DB_RETRY_FORCE_REFRESH or True) + + Yields: + None + + Example: + with db_retry_context(max_retries=5, force_refresh=True): + model.save() + other_model.delete() + MyModel.objects.filter(...).update(...) 
+ """ + # Get defaults from environment if not provided + defaults = get_default_retry_settings() + final_max_retries = ( + max_retries if max_retries is not None else defaults["max_retries"] + ) + final_base_delay = base_delay if base_delay is not None else defaults["base_delay"] + final_max_delay = max_delay if max_delay is not None else defaults["max_delay"] + final_force_refresh = ( + force_refresh if force_refresh is not None else defaults["force_refresh"] + ) + + retry_count = 0 + + while retry_count <= final_max_retries: + try: + yield + return # Success - exit the retry loop + except Exception as e: + error_type, needs_refresh, use_extended = _classify_database_error(e) + + if not DatabaseErrorPatterns.is_retryable_error(error_type): + logger.debug( + LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=e) + ) + raise + + # For database unavailable errors, use extended settings if configured + if use_extended and retry_count == 0: + extended_settings = get_default_retry_settings(use_extended=True) + final_max_retries = extended_settings["max_retries"] + final_base_delay = extended_settings["base_delay"] + final_max_delay = extended_settings["max_delay"] + + if retry_count < final_max_retries: + delay = min(final_base_delay * (2**retry_count), final_max_delay) + retry_count += 1 + + # Handle connection pool refresh for stale connections + if needs_refresh and final_force_refresh: + logger.warning( + LogMessages.format_message( + LogMessages.POOL_CORRUPTION_DETECTED, + attempt=retry_count, + total=final_max_retries + 1, + error=e, + ) + ) + try: + close_old_connections() + logger.info(LogMessages.POOL_REFRESH_SUCCESS) + except Exception as refresh_error: + logger.warning( + LogMessages.format_message( + LogMessages.POOL_REFRESH_FAILED, error=refresh_error + ) + ) + else: + # Choose appropriate retry message based on error type + if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE: + message = LogMessages.DATABASE_UNAVAILABLE_RETRY + else: + message = LogMessages.CONNECTION_ERROR_RETRY + + logger.warning( + LogMessages.format_message( + message, + attempt=retry_count, + total=final_max_retries + 1, + error=e, + delay=delay, + ) + ) + + time.sleep(delay) + continue + else: + logger.error( + LogMessages.format_message( + LogMessages.MAX_RETRIES_EXCEEDED, + total=final_max_retries + 1, + error=e, + ) + ) + raise + + +def retry_database_operation( + operation: Callable, + max_retries: int | None = None, + base_delay: float | None = None, + max_delay: float | None = None, + force_refresh: bool | None = None, +) -> Any: + """Execute a database operation with retry logic. + + Uses environment variables DB_RETRY_MAX_RETRIES, DB_RETRY_BASE_DELAY, DB_RETRY_MAX_DELAY, + DB_RETRY_FORCE_REFRESH for defaults if parameters are not provided. 
+ + Args: + operation: A callable that performs the database operation + max_retries: Maximum number of retry attempts (default from env: DB_RETRY_MAX_RETRIES or 3) + base_delay: Initial delay between retries in seconds (default from env: DB_RETRY_BASE_DELAY or 1.0) + max_delay: Maximum delay between retries in seconds (default from env: DB_RETRY_MAX_DELAY or 30.0) + force_refresh: Force connection pool refresh on stale connections (default from env: DB_RETRY_FORCE_REFRESH or True) + + Returns: + The result of the operation + + Example: + result = retry_database_operation( + lambda: MyModel.objects.filter(active=True).update(status='processed'), + max_retries=5, + force_refresh=True + ) + """ + # Get defaults from environment if not provided + defaults = get_default_retry_settings() + final_max_retries = ( + max_retries if max_retries is not None else defaults["max_retries"] + ) + final_base_delay = base_delay if base_delay is not None else defaults["base_delay"] + final_max_delay = max_delay if max_delay is not None else defaults["max_delay"] + final_force_refresh = ( + force_refresh if force_refresh is not None else defaults["force_refresh"] + ) + + return _execute_with_retry( + operation=operation, + max_retries=final_max_retries, + base_delay=final_base_delay, + max_delay=final_max_delay, + force_refresh=final_force_refresh, + ) + + +# Convenience function with default settings +def quick_retry(operation: Callable) -> Any: + """Execute a database operation with environment/default retry settings. + + Uses environment variables DB_RETRY_MAX_RETRIES, DB_RETRY_BASE_DELAY, DB_RETRY_MAX_DELAY + or built-in defaults (3 retries, 1.0s base delay, 30.0s max delay). + + Args: + operation: A callable that performs the database operation + + Returns: + The result of the operation + + Example: + result = quick_retry(lambda: model.save()) + """ + return retry_database_operation(operation) diff --git a/backend/utils/models/organization_mixin.py b/backend/utils/models/organization_mixin.py index 1e178fe0c4..a041fdb81d 100644 --- a/backend/utils/models/organization_mixin.py +++ b/backend/utils/models/organization_mixin.py @@ -1,6 +1,7 @@ # TODO:V2 class from account_v2.models import Organization from django.db import models +from utils.db_retry import db_retry from utils.user_context import UserContext @@ -17,6 +18,7 @@ class DefaultOrganizationMixin(models.Model): class Meta: abstract = True + @db_retry() # Add retry for connection drops during organization assignment def save(self, *args, **kwargs): if self.organization is None: self.organization = UserContext.get_organization() @@ -24,6 +26,7 @@ def save(self, *args, **kwargs): class DefaultOrganizationManagerMixin(models.Manager): + @db_retry() # Add retry for connection drops during queryset organization filtering def get_queryset(self): organization = UserContext.get_organization() return super().get_queryset().filter(organization=organization) diff --git a/backend/utils/user_context.py b/backend/utils/user_context.py index 71d63806b5..4b5c7dce54 100644 --- a/backend/utils/user_context.py +++ b/backend/utils/user_context.py @@ -2,6 +2,7 @@ from django.db.utils import ProgrammingError from utils.constants import Account +from utils.db_retry import db_retry from utils.local_context import StateStore @@ -16,6 +17,7 @@ def set_organization_identifier(organization_identifier: str) -> None: StateStore.set(Account.ORGANIZATION_ID, organization_identifier) @staticmethod + @db_retry() # Add retry for connection drops during organization lookup def 
get_organization() -> Organization | None: organization_id = StateStore.get(Account.ORGANIZATION_ID) try: diff --git a/backend/workflow_manager/file_execution/models.py b/backend/workflow_manager/file_execution/models.py index fd75b163fb..4ce6082d31 100644 --- a/backend/workflow_manager/file_execution/models.py +++ b/backend/workflow_manager/file_execution/models.py @@ -4,6 +4,7 @@ from django.db import models from utils.common_utils import CommonUtils +from utils.db_retry import db_retry from utils.models.base_model import BaseModel from workflow_manager.endpoint_v2.dto import FileHash @@ -16,6 +17,7 @@ class WorkflowFileExecutionManager(models.Manager): + @db_retry() # Use environment defaults for retry settings def get_or_create_file_execution( self, workflow_execution: Any, @@ -53,6 +55,7 @@ def get_or_create_file_execution( return execution_file + @db_retry() # Use environment defaults for retry settings def _update_execution_file( self, execution_file: "WorkflowFileExecution", file_hash: FileHash ) -> None: @@ -118,6 +121,7 @@ def __str__(self): f"(WorkflowExecution: {self.workflow_execution})" ) + @db_retry() # Use environment defaults for retry settings def update_status( self, status: ExecutionStatus, @@ -221,6 +225,7 @@ def is_completed(self) -> bool: """ return self.status is not None and self.status == ExecutionStatus.COMPLETED + @db_retry() # Use environment defaults for retry settings def update( self, file_hash: str = None, diff --git a/backend/workflow_manager/workflow_v2/execution.py b/backend/workflow_manager/workflow_v2/execution.py index 3de4cb6c10..2c0adab7e8 100644 --- a/backend/workflow_manager/workflow_v2/execution.py +++ b/backend/workflow_manager/workflow_v2/execution.py @@ -9,6 +9,7 @@ from tool_instance_v2.models import ToolInstance from tool_instance_v2.tool_processor import ToolProcessor from usage_v2.helper import UsageHelper +from utils.db_retry import db_retry from utils.local_context import StateStore from utils.user_context import UserContext @@ -391,6 +392,7 @@ def initiate_tool_execution( self.publish_log("Trying to fetch results from cache") @staticmethod + @db_retry() def update_execution_err(execution_id: str, err_msg: str = "") -> WorkflowExecution: try: execution = WorkflowExecution.objects.get(pk=execution_id) diff --git a/backend/workflow_manager/workflow_v2/models/execution.py b/backend/workflow_manager/workflow_v2/models/execution.py index 31c8988fe8..780190c164 100644 --- a/backend/workflow_manager/workflow_v2/models/execution.py +++ b/backend/workflow_manager/workflow_v2/models/execution.py @@ -12,6 +12,7 @@ from usage_v2.constants import UsageKeys from usage_v2.models import Usage from utils.common_utils import CommonUtils +from utils.db_retry import db_retry from utils.models.base_model import BaseModel from workflow_manager.execution.dto import ExecutionCache @@ -41,6 +42,7 @@ def for_user(self, user) -> QuerySet: # Return executions where the workflow's created_by matches the user return self.filter(workflow__created_by=user) + @db_retry() # Use environment defaults for retry settings def clean_invalid_workflows(self): """Remove execution records with invalid workflow references. 
@@ -237,6 +239,7 @@ def __str__(self) -> str: f"error message: {self.error_message})" ) + @db_retry() # Use environment defaults for retry settings def update_execution( self, status: ExecutionStatus | None = None, @@ -271,6 +274,7 @@ def update_execution( self.save() + @db_retry() # Use environment defaults for retry settings def update_execution_err(self, err_msg: str = "") -> None: """Update execution status to ERROR with an error message. @@ -279,6 +283,7 @@ def update_execution_err(self, err_msg: str = "") -> None: """ self.update_execution(status=ExecutionStatus.ERROR, error=err_msg) + @db_retry() # Use environment defaults for retry settings def _handle_execution_cache(self): if not ExecutionCacheUtils.is_execution_exists( workflow_id=self.workflow.id, execution_id=self.id From f0979aa716bfae99979f94f1547736ed3f3274ad Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 9 Sep 2025 18:24:18 +0530 Subject: [PATCH 2/6] refactor to reduce complexity --- backend/backend/celery_db_retry.py | 223 +++++++++++++++++------------ backend/utils/db_retry.py | 206 +++++++++++++++----------- 2 files changed, 257 insertions(+), 172 deletions(-) diff --git a/backend/backend/celery_db_retry.py b/backend/backend/celery_db_retry.py index ded5ac1584..79dd7d9657 100644 --- a/backend/backend/celery_db_retry.py +++ b/backend/backend/celery_db_retry.py @@ -25,6 +25,99 @@ def should_use_builtin_retry() -> bool: ) +def _is_sqlalchemy_error(exception: Exception) -> bool: + """Check if exception is a SQLAlchemy error.""" + try: + from sqlalchemy.exc import OperationalError as SQLAlchemyOperationalError + + return isinstance(exception, SQLAlchemyOperationalError) + except ImportError: + return False + + +def _get_retry_settings_for_error( + error_type: DatabaseErrorType, + retry_count: int, + max_retries: int, + base_delay: float, + max_delay: float, +) -> dict: + """Get appropriate retry settings based on error type.""" + if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE and retry_count == 0: + extended_settings = RetryConfiguration.get_retry_settings(use_extended=True) + return { + "max_retries": extended_settings["max_retries"], + "base_delay": extended_settings["base_delay"], + "max_delay": extended_settings["max_delay"], + } + return { + "max_retries": max_retries, + "base_delay": base_delay, + "max_delay": max_delay, + } + + +def _handle_pool_refresh( + func: Callable, error: Exception, retry_count: int, total_retries: int +) -> None: + """Handle SQLAlchemy connection pool disposal.""" + logger.warning( + LogMessages.format_message( + LogMessages.POOL_CORRUPTION_DETECTED, + attempt=retry_count, + total=total_retries, + error=error, + ) + ) + try: + _dispose_sqlalchemy_engine() + logger.info("SQLAlchemy connection pool disposed successfully") + except Exception as refresh_error: + logger.warning( + LogMessages.format_message( + LogMessages.POOL_REFRESH_FAILED, + error=refresh_error, + ) + ) + + +def _log_retry_attempt( + error_type: DatabaseErrorType, + retry_count: int, + total_retries: int, + error: Exception, + delay: float, +) -> None: + """Log retry attempt with appropriate message.""" + if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE: + message = LogMessages.DATABASE_UNAVAILABLE_RETRY + else: + message = LogMessages.CONNECTION_ERROR_RETRY + + logger.warning( + LogMessages.format_message( + message, + attempt=retry_count, + total=total_retries, + error=error, + delay=delay, + ) + ) + + +def _log_success(retry_count: int) -> None: + """Log operation success.""" + if retry_count == 0: + 
logger.debug(LogMessages.OPERATION_SUCCESS) + else: + logger.info( + LogMessages.format_message( + LogMessages.OPERATION_SUCCESS_AFTER_RETRY, + retry_count=retry_count, + ) + ) + + def celery_db_retry_with_backoff( max_retries: int | None = None, base_delay: float | None = None, @@ -62,28 +155,10 @@ def wrapper(*args, **kwargs) -> Any: while retry_count <= max_retries: try: result = func(*args, **kwargs) - if retry_count == 0: - logger.debug(LogMessages.OPERATION_SUCCESS) - else: - logger.info( - LogMessages.format_message( - LogMessages.OPERATION_SUCCESS_AFTER_RETRY, - retry_count=retry_count, - ) - ) + _log_success(retry_count) return result except Exception as e: - # Import here to avoid circular import - try: - from sqlalchemy.exc import ( - OperationalError as SQLAlchemyOperationalError, - ) - - is_sqlalchemy_error = isinstance(e, SQLAlchemyOperationalError) - except ImportError: - is_sqlalchemy_error = False - - if not is_sqlalchemy_error: + if not _is_sqlalchemy_error(e): logger.debug( LogMessages.format_message( LogMessages.NON_RETRYABLE_ERROR, error=e @@ -94,84 +169,48 @@ def wrapper(*args, **kwargs) -> Any: # Use centralized error classification error_type, needs_refresh = DatabaseErrorPatterns.classify_error(e) - if DatabaseErrorPatterns.is_retryable_error(error_type): - # For database unavailable errors, use extended settings if configured - current_max_retries = max_retries - current_base_delay = base_delay - current_max_delay = max_delay - - if ( - error_type == DatabaseErrorType.DATABASE_UNAVAILABLE - and retry_count == 0 - ): - extended_settings = RetryConfiguration.get_retry_settings( - use_extended=True + if not DatabaseErrorPatterns.is_retryable_error(error_type): + logger.debug( + LogMessages.format_message( + LogMessages.NON_RETRYABLE_ERROR, error=e ) - current_max_retries = extended_settings["max_retries"] - current_base_delay = extended_settings["base_delay"] - current_max_delay = extended_settings["max_delay"] + ) + raise - if retry_count < current_max_retries: - delay = min( - current_base_delay * (2**retry_count), current_max_delay + # Get appropriate retry settings for this error type + current_settings = _get_retry_settings_for_error( + error_type, retry_count, max_retries, base_delay, max_delay + ) + + if retry_count < current_settings["max_retries"]: + delay = min( + current_settings["base_delay"] * (2**retry_count), + current_settings["max_delay"], + ) + retry_count += 1 + + # Handle connection pool refresh if needed + if needs_refresh: + _handle_pool_refresh( + func, e, retry_count, current_settings["max_retries"] + 1 ) - retry_count += 1 - - # Handle SQLAlchemy connection pool disposal for severe connection issues - if needs_refresh: - logger.warning( - LogMessages.format_message( - LogMessages.POOL_CORRUPTION_DETECTED, - attempt=retry_count, - total=current_max_retries + 1, - error=e, - ) - ) - try: - _dispose_sqlalchemy_engine(func) - logger.info( - "SQLAlchemy connection pool disposed successfully" - ) - except Exception as refresh_error: - logger.warning( - LogMessages.format_message( - LogMessages.POOL_REFRESH_FAILED, - error=refresh_error, - ) - ) - else: - # Choose appropriate retry message based on error type - if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE: - message = LogMessages.DATABASE_UNAVAILABLE_RETRY - else: - message = LogMessages.CONNECTION_ERROR_RETRY - - logger.warning( - LogMessages.format_message( - message, - attempt=retry_count, - total=current_max_retries + 1, - error=e, - delay=delay, - ) - ) - - time.sleep(delay) - 
continue else: - logger.error( - LogMessages.format_message( - LogMessages.MAX_RETRIES_EXCEEDED, - total=current_max_retries + 1, - error=e, - ) + _log_retry_attempt( + error_type, + retry_count, + current_settings["max_retries"] + 1, + e, + delay, ) - raise + + time.sleep(delay) + continue else: - # Not a connection error, re-raise immediately - logger.debug( + logger.error( LogMessages.format_message( - LogMessages.NON_RETRYABLE_ERROR, error=e + LogMessages.MAX_RETRIES_EXCEEDED, + total=current_settings["max_retries"] + 1, + error=e, ) ) raise @@ -184,7 +223,7 @@ def wrapper(*args, **kwargs) -> Any: return decorator -def _dispose_sqlalchemy_engine(func): +def _dispose_sqlalchemy_engine(): """Dispose SQLAlchemy engine to force connection pool recreation. This is called when we detect severe connection issues that require diff --git a/backend/utils/db_retry.py b/backend/utils/db_retry.py index 752284e2f1..f3d1d9d7bb 100644 --- a/backend/utils/db_retry.py +++ b/backend/utils/db_retry.py @@ -110,6 +110,76 @@ def _classify_database_error(error: Exception): return error_type, needs_refresh, use_extended_retry +def _update_retry_settings_for_extended(retry_count: int, use_extended: bool) -> dict: + """Get extended retry settings if needed for database unavailable errors.""" + if use_extended and retry_count == 0: + return get_default_retry_settings(use_extended=True) + return {} + + +def _handle_connection_pool_refresh( + error: Exception, retry_count: int, max_retries: int +) -> None: + """Handle connection pool refresh for stale connections.""" + logger.warning( + LogMessages.format_message( + LogMessages.POOL_CORRUPTION_DETECTED, + attempt=retry_count, + total=max_retries + 1, + error=error, + ) + ) + try: + close_old_connections() + logger.info(LogMessages.POOL_REFRESH_SUCCESS) + except Exception as refresh_error: + logger.warning( + LogMessages.format_message( + LogMessages.POOL_REFRESH_FAILED, error=refresh_error + ) + ) + + +def _log_retry_attempt( + error_type: DatabaseErrorType, + retry_count: int, + max_retries: int, + error: Exception, + delay: float, +) -> None: + """Log retry attempt with appropriate message based on error type.""" + if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE: + message = LogMessages.DATABASE_UNAVAILABLE_RETRY + else: + message = LogMessages.CONNECTION_ERROR_RETRY + + logger.warning( + LogMessages.format_message( + message, + attempt=retry_count, + total=max_retries + 1, + error=error, + delay=delay, + ) + ) + + +def _handle_retry_attempt( + error: Exception, + error_type: DatabaseErrorType, + needs_refresh: bool, + force_refresh: bool, + retry_count: int, + max_retries: int, + delay: float, +) -> None: + """Handle a single retry attempt - either refresh pool or log retry.""" + if needs_refresh and force_refresh: + _handle_connection_pool_refresh(error, retry_count, max_retries) + else: + _log_retry_attempt(error_type, retry_count, max_retries, error, delay) + + def _execute_with_retry( operation: Callable, max_retries: int = 3, @@ -147,9 +217,11 @@ def _execute_with_retry( ) raise - # For database unavailable errors, use extended settings if configured - if use_extended and retry_count == 0: - extended_settings = get_default_retry_settings(use_extended=True) + # Update settings for extended retry if needed + extended_settings = _update_retry_settings_for_extended( + retry_count, use_extended + ) + if extended_settings: max_retries = extended_settings["max_retries"] base_delay = extended_settings["base_delay"] max_delay = extended_settings["max_delay"] 
@@ -158,41 +230,15 @@ def _execute_with_retry( delay = min(base_delay * (2**retry_count), max_delay) retry_count += 1 - # Handle connection pool refresh for stale connections - if needs_refresh and force_refresh: - logger.warning( - LogMessages.format_message( - LogMessages.POOL_CORRUPTION_DETECTED, - attempt=retry_count, - total=max_retries + 1, - error=e, - ) - ) - try: - close_old_connections() - logger.info(LogMessages.POOL_REFRESH_SUCCESS) - except Exception as refresh_error: - logger.warning( - LogMessages.format_message( - LogMessages.POOL_REFRESH_FAILED, error=refresh_error - ) - ) - else: - # Choose appropriate retry message based on error type - if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE: - message = LogMessages.DATABASE_UNAVAILABLE_RETRY - else: - message = LogMessages.CONNECTION_ERROR_RETRY - - logger.warning( - LogMessages.format_message( - message, - attempt=retry_count, - total=max_retries + 1, - error=e, - delay=delay, - ) - ) + _handle_retry_attempt( + e, + error_type, + needs_refresh, + force_refresh, + retry_count, + max_retries, + delay, + ) time.sleep(delay) continue @@ -268,6 +314,33 @@ def operation(): return decorator +def _handle_context_retry_logic( + retry_count: int, + final_max_retries: int, + final_base_delay: float, + final_max_delay: float, + final_force_refresh: bool, + error: Exception, + error_type: DatabaseErrorType, + needs_refresh: bool, +) -> tuple[int, float]: + """Handle retry logic for context manager - returns updated retry_count and delay.""" + delay = min(final_base_delay * (2**retry_count), final_max_delay) + retry_count += 1 + + _handle_retry_attempt( + error, + error_type, + needs_refresh, + final_force_refresh, + retry_count, + final_max_retries, + delay, + ) + + return retry_count, delay + + @contextmanager def db_retry_context( max_retries: int | None = None, @@ -321,53 +394,26 @@ def db_retry_context( ) raise - # For database unavailable errors, use extended settings if configured - if use_extended and retry_count == 0: - extended_settings = get_default_retry_settings(use_extended=True) + # Update settings for extended retry if needed + extended_settings = _update_retry_settings_for_extended( + retry_count, use_extended + ) + if extended_settings: final_max_retries = extended_settings["max_retries"] final_base_delay = extended_settings["base_delay"] final_max_delay = extended_settings["max_delay"] if retry_count < final_max_retries: - delay = min(final_base_delay * (2**retry_count), final_max_delay) - retry_count += 1 - - # Handle connection pool refresh for stale connections - if needs_refresh and final_force_refresh: - logger.warning( - LogMessages.format_message( - LogMessages.POOL_CORRUPTION_DETECTED, - attempt=retry_count, - total=final_max_retries + 1, - error=e, - ) - ) - try: - close_old_connections() - logger.info(LogMessages.POOL_REFRESH_SUCCESS) - except Exception as refresh_error: - logger.warning( - LogMessages.format_message( - LogMessages.POOL_REFRESH_FAILED, error=refresh_error - ) - ) - else: - # Choose appropriate retry message based on error type - if error_type == DatabaseErrorType.DATABASE_UNAVAILABLE: - message = LogMessages.DATABASE_UNAVAILABLE_RETRY - else: - message = LogMessages.CONNECTION_ERROR_RETRY - - logger.warning( - LogMessages.format_message( - message, - attempt=retry_count, - total=final_max_retries + 1, - error=e, - delay=delay, - ) - ) - + retry_count, delay = _handle_context_retry_logic( + retry_count, + final_max_retries, + final_base_delay, + final_max_delay, + final_force_refresh, + 
e, + error_type, + needs_refresh, + ) time.sleep(delay) continue else: From 78a8660a3768d135bbbe7fa7a8d9a922fe0621b4 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 9 Sep 2025 19:21:53 +0530 Subject: [PATCH 3/6] reducing complexity of methods --- backend/backend/celery_db_retry.py | 101 +++++++++++++++++------------ backend/utils/db_retry.py | 85 +++++++++++++++--------- 2 files changed, 114 insertions(+), 72 deletions(-) diff --git a/backend/backend/celery_db_retry.py b/backend/backend/celery_db_retry.py index 79dd7d9657..817a23abb9 100644 --- a/backend/backend/celery_db_retry.py +++ b/backend/backend/celery_db_retry.py @@ -57,9 +57,7 @@ def _get_retry_settings_for_error( } -def _handle_pool_refresh( - func: Callable, error: Exception, retry_count: int, total_retries: int -) -> None: +def _handle_pool_refresh(error: Exception, retry_count: int, total_retries: int) -> None: """Handle SQLAlchemy connection pool disposal.""" logger.warning( LogMessages.format_message( @@ -105,6 +103,59 @@ def _log_retry_attempt( ) +def _handle_non_sqlalchemy_error(error: Exception) -> None: + """Handle non-SQLAlchemy errors by logging and re-raising.""" + logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) + raise error + + +def _handle_non_retryable_error(error: Exception) -> None: + """Handle non-retryable errors by logging and re-raising.""" + logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) + raise error + + +def _execute_retry_attempt( + error: Exception, + error_type: DatabaseErrorType, + needs_refresh: bool, + retry_count: int, + current_settings: dict, +) -> tuple[int, float]: + """Execute a retry attempt and return updated retry_count and delay.""" + delay = min( + current_settings["base_delay"] * (2**retry_count), + current_settings["max_delay"], + ) + retry_count += 1 + + # Handle connection pool refresh if needed + if needs_refresh: + _handle_pool_refresh(error, retry_count, current_settings["max_retries"] + 1) + else: + _log_retry_attempt( + error_type, + retry_count, + current_settings["max_retries"] + 1, + error, + delay, + ) + + return retry_count, delay + + +def _handle_max_retries_exceeded(error: Exception, total_retries: int) -> None: + """Handle the case when max retries are exceeded.""" + logger.error( + LogMessages.format_message( + LogMessages.MAX_RETRIES_EXCEEDED, + total=total_retries, + error=error, + ) + ) + raise error + + def _log_success(retry_count: int) -> None: """Log operation success.""" if retry_count == 0: @@ -159,23 +210,13 @@ def wrapper(*args, **kwargs) -> Any: return result except Exception as e: if not _is_sqlalchemy_error(e): - logger.debug( - LogMessages.format_message( - LogMessages.NON_RETRYABLE_ERROR, error=e - ) - ) - raise + _handle_non_sqlalchemy_error(e) # Use centralized error classification error_type, needs_refresh = DatabaseErrorPatterns.classify_error(e) if not DatabaseErrorPatterns.is_retryable_error(error_type): - logger.debug( - LogMessages.format_message( - LogMessages.NON_RETRYABLE_ERROR, error=e - ) - ) - raise + _handle_non_retryable_error(e) # Get appropriate retry settings for this error type current_settings = _get_retry_settings_for_error( @@ -183,37 +224,15 @@ def wrapper(*args, **kwargs) -> Any: ) if retry_count < current_settings["max_retries"]: - delay = min( - current_settings["base_delay"] * (2**retry_count), - current_settings["max_delay"], + retry_count, delay = _execute_retry_attempt( + e, error_type, needs_refresh, retry_count, current_settings ) - retry_count += 
1 - - # Handle connection pool refresh if needed - if needs_refresh: - _handle_pool_refresh( - func, e, retry_count, current_settings["max_retries"] + 1 - ) - else: - _log_retry_attempt( - error_type, - retry_count, - current_settings["max_retries"] + 1, - e, - delay, - ) - time.sleep(delay) continue else: - logger.error( - LogMessages.format_message( - LogMessages.MAX_RETRIES_EXCEEDED, - total=current_settings["max_retries"] + 1, - error=e, - ) + _handle_max_retries_exceeded( + e, current_settings["max_retries"] + 1 ) - raise # This should never be reached, but included for completeness return func(*args, **kwargs) diff --git a/backend/utils/db_retry.py b/backend/utils/db_retry.py index f3d1d9d7bb..ad91f07660 100644 --- a/backend/utils/db_retry.py +++ b/backend/utils/db_retry.py @@ -314,6 +314,39 @@ def operation(): return decorator +def _handle_context_non_retryable_error(error: Exception) -> None: + """Handle non-retryable errors in context manager.""" + logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) + raise error + + +def _handle_context_max_retries_exceeded(error: Exception, total_retries: int) -> None: + """Handle max retries exceeded in context manager.""" + logger.error( + LogMessages.format_message( + LogMessages.MAX_RETRIES_EXCEEDED, + total=total_retries, + error=error, + ) + ) + raise error + + +def _update_context_retry_settings( + retry_count: int, use_extended: bool, current_settings: dict +) -> dict: + """Update retry settings for extended retry if needed.""" + extended_settings = _update_retry_settings_for_extended(retry_count, use_extended) + if extended_settings: + return { + "max_retries": extended_settings["max_retries"], + "base_delay": extended_settings["base_delay"], + "max_delay": extended_settings["max_delay"], + "force_refresh": current_settings["force_refresh"], + } + return current_settings + + def _handle_context_retry_logic( retry_count: int, final_max_retries: int, @@ -370,18 +403,20 @@ def db_retry_context( """ # Get defaults from environment if not provided defaults = get_default_retry_settings() - final_max_retries = ( - max_retries if max_retries is not None else defaults["max_retries"] - ) - final_base_delay = base_delay if base_delay is not None else defaults["base_delay"] - final_max_delay = max_delay if max_delay is not None else defaults["max_delay"] - final_force_refresh = ( - force_refresh if force_refresh is not None else defaults["force_refresh"] - ) + current_settings = { + "max_retries": max_retries + if max_retries is not None + else defaults["max_retries"], + "base_delay": base_delay if base_delay is not None else defaults["base_delay"], + "max_delay": max_delay if max_delay is not None else defaults["max_delay"], + "force_refresh": force_refresh + if force_refresh is not None + else defaults["force_refresh"], + } retry_count = 0 - while retry_count <= final_max_retries: + while retry_count <= current_settings["max_retries"]: try: yield return # Success - exit the retry loop @@ -389,27 +424,20 @@ def db_retry_context( error_type, needs_refresh, use_extended = _classify_database_error(e) if not DatabaseErrorPatterns.is_retryable_error(error_type): - logger.debug( - LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=e) - ) - raise + _handle_context_non_retryable_error(e) # Update settings for extended retry if needed - extended_settings = _update_retry_settings_for_extended( - retry_count, use_extended + current_settings = _update_context_retry_settings( + retry_count, use_extended, 
current_settings ) - if extended_settings: - final_max_retries = extended_settings["max_retries"] - final_base_delay = extended_settings["base_delay"] - final_max_delay = extended_settings["max_delay"] - if retry_count < final_max_retries: + if retry_count < current_settings["max_retries"]: retry_count, delay = _handle_context_retry_logic( retry_count, - final_max_retries, - final_base_delay, - final_max_delay, - final_force_refresh, + current_settings["max_retries"], + current_settings["base_delay"], + current_settings["max_delay"], + current_settings["force_refresh"], e, error_type, needs_refresh, @@ -417,14 +445,9 @@ def db_retry_context( time.sleep(delay) continue else: - logger.error( - LogMessages.format_message( - LogMessages.MAX_RETRIES_EXCEEDED, - total=final_max_retries + 1, - error=e, - ) + _handle_context_max_retries_exceeded( + e, current_settings["max_retries"] + 1 ) - raise def retry_database_operation( From 7c0d03cc586e0b01ca9940aa8dce7855ecb8d9ad Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 9 Sep 2025 19:38:36 +0530 Subject: [PATCH 4/6] reducing complexity of methods --- backend/backend/celery_db_retry.py | 81 +++++++++++++++++++----------- backend/utils/db_retry.py | 18 ++++--- 2 files changed, 61 insertions(+), 38 deletions(-) diff --git a/backend/backend/celery_db_retry.py b/backend/backend/celery_db_retry.py index 817a23abb9..8a4cbfa96d 100644 --- a/backend/backend/celery_db_retry.py +++ b/backend/backend/celery_db_retry.py @@ -57,7 +57,9 @@ def _get_retry_settings_for_error( } -def _handle_pool_refresh(error: Exception, retry_count: int, total_retries: int) -> None: +def _handle_pool_refresh( + error: BaseException, retry_count: int, total_retries: int +) -> None: """Handle SQLAlchemy connection pool disposal.""" logger.warning( LogMessages.format_message( @@ -83,7 +85,7 @@ def _log_retry_attempt( error_type: DatabaseErrorType, retry_count: int, total_retries: int, - error: Exception, + error: BaseException, delay: float, ) -> None: """Log retry attempt with appropriate message.""" @@ -103,20 +105,20 @@ def _log_retry_attempt( ) -def _handle_non_sqlalchemy_error(error: Exception) -> None: +def _handle_non_sqlalchemy_error(error: BaseException) -> None: """Handle non-SQLAlchemy errors by logging and re-raising.""" logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) - raise error + raise -def _handle_non_retryable_error(error: Exception) -> None: +def _handle_non_retryable_error(error: BaseException) -> None: """Handle non-retryable errors by logging and re-raising.""" logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) - raise error + raise def _execute_retry_attempt( - error: Exception, + error: BaseException, error_type: DatabaseErrorType, needs_refresh: bool, retry_count: int, @@ -144,7 +146,7 @@ def _execute_retry_attempt( return retry_count, delay -def _handle_max_retries_exceeded(error: Exception, total_retries: int) -> None: +def _handle_max_retries_exceeded(error: BaseException, total_retries: int) -> None: """Handle the case when max retries are exceeded.""" logger.error( LogMessages.format_message( @@ -153,7 +155,43 @@ def _handle_max_retries_exceeded(error: Exception, total_retries: int) -> None: error=error, ) ) - raise error + raise + + +def _process_celery_exception( + error: BaseException, + retry_count: int, + max_retries: int, + base_delay: float, + max_delay: float, +) -> tuple[int, float] | None: + """Process a Celery exception and return retry info or None if not retryable. 
+ + Returns: + tuple[int, float]: (updated_retry_count, delay) if should retry + None: if should not retry (will re-raise) + """ + if not _is_sqlalchemy_error(error): + _handle_non_sqlalchemy_error(error) + + # Use centralized error classification + error_type, needs_refresh = DatabaseErrorPatterns.classify_error(error) + + if not DatabaseErrorPatterns.is_retryable_error(error_type): + _handle_non_retryable_error(error) + + # Get appropriate retry settings for this error type + current_settings = _get_retry_settings_for_error( + error_type, retry_count, max_retries, base_delay, max_delay + ) + + if retry_count < current_settings["max_retries"]: + return _execute_retry_attempt( + error, error_type, needs_refresh, retry_count, current_settings + ) + else: + _handle_max_retries_exceeded(error, current_settings["max_retries"] + 1) + return None # This line will never be reached due to exception, but added for completeness def _log_success(retry_count: int) -> None: @@ -209,30 +247,13 @@ def wrapper(*args, **kwargs) -> Any: _log_success(retry_count) return result except Exception as e: - if not _is_sqlalchemy_error(e): - _handle_non_sqlalchemy_error(e) - - # Use centralized error classification - error_type, needs_refresh = DatabaseErrorPatterns.classify_error(e) - - if not DatabaseErrorPatterns.is_retryable_error(error_type): - _handle_non_retryable_error(e) - - # Get appropriate retry settings for this error type - current_settings = _get_retry_settings_for_error( - error_type, retry_count, max_retries, base_delay, max_delay + retry_info = _process_celery_exception( + e, retry_count, max_retries, base_delay, max_delay ) - - if retry_count < current_settings["max_retries"]: - retry_count, delay = _execute_retry_attempt( - e, error_type, needs_refresh, retry_count, current_settings - ) + if retry_info: + retry_count, delay = retry_info time.sleep(delay) continue - else: - _handle_max_retries_exceeded( - e, current_settings["max_retries"] + 1 - ) # This should never be reached, but included for completeness return func(*args, **kwargs) diff --git a/backend/utils/db_retry.py b/backend/utils/db_retry.py index ad91f07660..1d81ce8a64 100644 --- a/backend/utils/db_retry.py +++ b/backend/utils/db_retry.py @@ -118,7 +118,7 @@ def _update_retry_settings_for_extended(retry_count: int, use_extended: bool) -> def _handle_connection_pool_refresh( - error: Exception, retry_count: int, max_retries: int + error: BaseException, retry_count: int, max_retries: int ) -> None: """Handle connection pool refresh for stale connections.""" logger.warning( @@ -144,7 +144,7 @@ def _log_retry_attempt( error_type: DatabaseErrorType, retry_count: int, max_retries: int, - error: Exception, + error: BaseException, delay: float, ) -> None: """Log retry attempt with appropriate message based on error type.""" @@ -165,7 +165,7 @@ def _log_retry_attempt( def _handle_retry_attempt( - error: Exception, + error: BaseException, error_type: DatabaseErrorType, needs_refresh: bool, force_refresh: bool, @@ -314,13 +314,15 @@ def operation(): return decorator -def _handle_context_non_retryable_error(error: Exception) -> None: +def _handle_context_non_retryable_error(error: BaseException) -> None: """Handle non-retryable errors in context manager.""" logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) - raise error + raise -def _handle_context_max_retries_exceeded(error: Exception, total_retries: int) -> None: +def _handle_context_max_retries_exceeded( + error: BaseException, total_retries: int +) -> 
None: """Handle max retries exceeded in context manager.""" logger.error( LogMessages.format_message( @@ -329,7 +331,7 @@ def _handle_context_max_retries_exceeded(error: Exception, total_retries: int) - error=error, ) ) - raise error + raise def _update_context_retry_settings( @@ -353,7 +355,7 @@ def _handle_context_retry_logic( final_base_delay: float, final_max_delay: float, final_force_refresh: bool, - error: Exception, + error: BaseException, error_type: DatabaseErrorType, needs_refresh: bool, ) -> tuple[int, float]: From 0eeae21db34cad5f3a1f17e6deab69cce68f20fc Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 9 Sep 2025 19:53:06 +0530 Subject: [PATCH 5/6] reducing complexity of methods --- backend/backend/celery_db_retry.py | 25 ++++++++++++------------- backend/utils/db_retry.py | 20 ++++++++------------ 2 files changed, 20 insertions(+), 25 deletions(-) diff --git a/backend/backend/celery_db_retry.py b/backend/backend/celery_db_retry.py index 8a4cbfa96d..f555df5288 100644 --- a/backend/backend/celery_db_retry.py +++ b/backend/backend/celery_db_retry.py @@ -105,16 +105,14 @@ def _log_retry_attempt( ) -def _handle_non_sqlalchemy_error(error: BaseException) -> None: - """Handle non-SQLAlchemy errors by logging and re-raising.""" +def _log_non_sqlalchemy_error(error: BaseException) -> None: + """Log non-SQLAlchemy errors.""" logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) - raise -def _handle_non_retryable_error(error: BaseException) -> None: - """Handle non-retryable errors by logging and re-raising.""" +def _log_non_retryable_error(error: BaseException) -> None: + """Log non-retryable errors.""" logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) - raise def _execute_retry_attempt( @@ -146,8 +144,8 @@ def _execute_retry_attempt( return retry_count, delay -def _handle_max_retries_exceeded(error: BaseException, total_retries: int) -> None: - """Handle the case when max retries are exceeded.""" +def _log_max_retries_exceeded(error: BaseException, total_retries: int) -> None: + """Log when max retries are exceeded.""" logger.error( LogMessages.format_message( LogMessages.MAX_RETRIES_EXCEEDED, @@ -155,7 +153,6 @@ def _handle_max_retries_exceeded(error: BaseException, total_retries: int) -> No error=error, ) ) - raise def _process_celery_exception( @@ -172,13 +169,15 @@ def _process_celery_exception( None: if should not retry (will re-raise) """ if not _is_sqlalchemy_error(error): - _handle_non_sqlalchemy_error(error) + _log_non_sqlalchemy_error(error) + raise # Use centralized error classification error_type, needs_refresh = DatabaseErrorPatterns.classify_error(error) if not DatabaseErrorPatterns.is_retryable_error(error_type): - _handle_non_retryable_error(error) + _log_non_retryable_error(error) + raise # Get appropriate retry settings for this error type current_settings = _get_retry_settings_for_error( @@ -190,8 +189,8 @@ def _process_celery_exception( error, error_type, needs_refresh, retry_count, current_settings ) else: - _handle_max_retries_exceeded(error, current_settings["max_retries"] + 1) - return None # This line will never be reached due to exception, but added for completeness + _log_max_retries_exceeded(error, current_settings["max_retries"] + 1) + raise def _log_success(retry_count: int) -> None: diff --git a/backend/utils/db_retry.py b/backend/utils/db_retry.py index 1d81ce8a64..66e264db23 100644 --- a/backend/utils/db_retry.py +++ b/backend/utils/db_retry.py @@ -314,16 +314,13 @@ def operation(): 
return decorator -def _handle_context_non_retryable_error(error: BaseException) -> None: - """Handle non-retryable errors in context manager.""" +def _log_context_non_retryable_error(error: BaseException) -> None: + """Log non-retryable errors in context manager.""" logger.debug(LogMessages.format_message(LogMessages.NON_RETRYABLE_ERROR, error=error)) - raise -def _handle_context_max_retries_exceeded( - error: BaseException, total_retries: int -) -> None: - """Handle max retries exceeded in context manager.""" +def _log_context_max_retries_exceeded(error: BaseException, total_retries: int) -> None: + """Log max retries exceeded in context manager.""" logger.error( LogMessages.format_message( LogMessages.MAX_RETRIES_EXCEEDED, @@ -331,7 +328,6 @@ def _handle_context_max_retries_exceeded( error=error, ) ) - raise def _update_context_retry_settings( @@ -426,7 +422,8 @@ def db_retry_context( error_type, needs_refresh, use_extended = _classify_database_error(e) if not DatabaseErrorPatterns.is_retryable_error(error_type): - _handle_context_non_retryable_error(e) + _log_context_non_retryable_error(e) + raise # Update settings for extended retry if needed current_settings = _update_context_retry_settings( @@ -447,9 +444,8 @@ def db_retry_context( time.sleep(delay) continue else: - _handle_context_max_retries_exceeded( - e, current_settings["max_retries"] + 1 - ) + _log_context_max_retries_exceeded(e, current_settings["max_retries"] + 1) + raise def retry_database_operation( From e606f1eb94d89f32e90c000c3a83e825ee1bafb7 Mon Sep 17 00:00:00 2001 From: ali Date: Wed, 10 Sep 2025 08:20:21 +0530 Subject: [PATCH 6/6] use env values for configurations, rplace print with logger --- backend/backend/celery_config.py | 9 +++++---- backend/backend/celery_db_retry.py | 28 +++++++++++++++++++--------- backend/sample.env | 14 ++++++++++---- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/backend/backend/celery_config.py b/backend/backend/celery_config.py index 3629af6553..80c2299fc0 100644 --- a/backend/backend/celery_config.py +++ b/backend/backend/celery_config.py @@ -1,3 +1,4 @@ +import logging import os from urllib.parse import quote_plus @@ -5,6 +6,8 @@ from backend.celery_db_retry import get_celery_db_engine_options, should_use_builtin_retry +logger = logging.getLogger(__name__) + class CeleryConfig: """Specifies celery configuration with hybrid retry support. @@ -57,12 +60,10 @@ class CeleryConfig: os.environ.get("CELERY_RESULT_BACKEND_MAX_SLEEP_BETWEEN_RETRIES_MS", "30000") ) - print( + logger.info( f"[Celery Config] Using built-in retry: max_retries={result_backend_max_retries}, " - ) - print( f"base_sleep={result_backend_base_sleep_between_retries_ms}ms, max_sleep={result_backend_max_sleep_between_retries_ms}ms" ) else: # Custom retry is handled by patch_celery_database_backend() - print("[Celery Config] Using custom retry system (patching enabled)") + logger.info("[Celery Config] Using custom retry system (patching enabled)") diff --git a/backend/backend/celery_db_retry.py b/backend/backend/celery_db_retry.py index f555df5288..1fb4550dfd 100644 --- a/backend/backend/celery_db_retry.py +++ b/backend/backend/celery_db_retry.py @@ -214,7 +214,7 @@ def celery_db_retry_with_backoff( """Decorator to retry Celery database backend operations with exponential backoff. This is specifically designed for Celery's database result backend operations - that may experience connection drops when using PgBouncer or database restarts. 
+ that may experience connection drops when using database proxy or database restarts. Args: max_retries: Maximum number of retry attempts (defaults to settings or 3) @@ -419,23 +419,33 @@ def _configure_builtin_retry(): def get_celery_db_engine_options(): - """Get SQLAlchemy engine options optimized for use with PgBouncer. + """Get SQLAlchemy engine options optimized for database connection pooling. Includes built-in retry configuration if CELERY_USE_BUILTIN_RETRY is enabled. - These options are designed to work well with PgBouncer connection pooling - without interfering with PgBouncer's pool management. + These options are designed to work well with database proxy connection pooling + without interfering with the database proxy's pool management. + + All options are configurable via environment variables for flexibility. Returns: dict: SQLAlchemy engine options """ return { # Connection health checking - "pool_pre_ping": True, # Test connections before use - # Minimal pooling (let PgBouncer handle the real pooling) - "pool_size": 5, # Small pool since PgBouncer handles real pooling - "max_overflow": 0, # No overflow, rely on PgBouncer - "pool_recycle": 3600, # Recycle connections every hour + "pool_pre_ping": RetryConfiguration.get_setting_value( + "CELERY_DB_POOL_PRE_PING", True + ), # Test connections before use + # Minimal pooling (let database proxy handle the real pooling) + "pool_size": RetryConfiguration.get_setting_value( + "CELERY_DB_POOL_SIZE", 5 + ), # Small pool since database proxy handles real pooling + "max_overflow": RetryConfiguration.get_setting_value( + "CELERY_DB_MAX_OVERFLOW", 0 + ), # No overflow, rely on database proxy + "pool_recycle": RetryConfiguration.get_setting_value( + "CELERY_DB_POOL_RECYCLE", 3600 + ), # Recycle connections every hour # Connection timeouts using centralized configuration "connect_args": { "connect_timeout": RetryConfiguration.get_setting_value( diff --git a/backend/sample.env b/backend/sample.env index 8b87125bed..816109f979 100644 --- a/backend/sample.env +++ b/backend/sample.env @@ -26,17 +26,23 @@ DB_SCHEMA="unstract" # Example: # CELERY_BACKEND_DB_NAME=unstract_celery_db -# Database connection retry settings (for handling connection drops with PgBouncer) +# Database connection retry settings (for handling connection drops with database proxy) # DB_CONNECTION_RETRY_COUNT=3 # DB_CONNECTION_RETRY_DELAY=1 # DB_CONNECTION_RETRY_MAX_DELAY=30 -# Celery database backend retry settings (for handling connection drops with PgBouncer) +# Celery database backend retry settings (for handling connection drops with database proxy) # CELERY_DB_RETRY_COUNT=3 # CELERY_DB_RETRY_DELAY=1 # CELERY_DB_RETRY_MAX_DELAY=30 -# CELERY_DB_CONNECT_TIMEOUT=30 -# CELERY_DB_ECHO_SQL=False + +# Celery database backend connection pool settings +# CELERY_DB_POOL_PRE_PING=True # Test connections before use +# CELERY_DB_POOL_SIZE=5 # Connection pool size +# CELERY_DB_MAX_OVERFLOW=0 # Max overflow connections +# CELERY_DB_POOL_RECYCLE=3600 # Recycle connections after (seconds) +# CELERY_DB_CONNECT_TIMEOUT=30 # Connection timeout (seconds) +# CELERY_DB_ECHO_SQL=False # Echo SQL queries for debugging # Redis REDIS_HOST="unstract-redis" REDIS_PORT=6379
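
Usage note (not part of the patch): a minimal sketch of how the retry helpers added in backend/utils/db_retry.py are intended to be applied. Only db_retry, db_retry_context, update_execution_err and their parameters come from this series; the model import path and the placeholder values are assumptions.

# Hypothetical usage sketch of the db_retry decorator and db_retry_context
# context manager from backend/utils/db_retry.py. The WorkflowExecution import
# path is assumed, not taken from this patch.
from utils.db_retry import db_retry, db_retry_context

from workflow_manager.workflow_v2.models.execution import WorkflowExecution  # assumed path


@db_retry()  # retry settings fall back to the DB_CONNECTION_RETRY_* env defaults
def fetch_execution(execution_id):
    # Plain ORM read; transient connection drops are retried with exponential backoff.
    return WorkflowExecution.objects.get(id=execution_id)


# Equivalent protection for an inline block of ORM calls, with explicit overrides:
with db_retry_context(max_retries=5, base_delay=2.0):
    execution = fetch_execution(execution_id="...")  # placeholder id
    execution.update_execution_err("file processing interrupted")  # placeholder message

The decorator suits single model methods (as done for update_execution above), while the context manager covers an ad-hoc block of ORM calls without wrapping each one.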
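
As a quick sanity check of the shared backoff formula delay = min(base_delay * 2**retry_count, max_delay), the schedule under the documented defaults (3 retries, 1 s base delay, 30 s cap) works out as follows; the snippet below only reproduces that arithmetic.

# Delay sequence produced by the backoff formula used in both retry paths,
# assuming the sample.env defaults: DB_CONNECTION_RETRY_COUNT=3,
# DB_CONNECTION_RETRY_DELAY=1, DB_CONNECTION_RETRY_MAX_DELAY=30.
base_delay, max_delay, max_retries = 1.0, 30.0, 3
delays = [min(base_delay * (2**attempt), max_delay) for attempt in range(max_retries)]
print(delays)  # [1.0, 2.0, 4.0] -> seconds slept before attempts 2, 3 and 4

When an error is classified as DATABASE_UNAVAILABLE, the extended settings returned by RetryConfiguration replace these values on the first retry, so the effective schedule can be longer than the default one shown here.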