Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ac86280
Added Database, Healthcheck, CircuitBreaker, FailureDetector
vladvildanov Jun 13, 2025
4f4a53c
Added DatabaseSelector, exceptions, refactored existing entities
vladvildanov Jun 17, 2025
acc68ef
Added MultiDbConfig
vladvildanov Jun 17, 2025
255bb0e
Added DatabaseConfig
vladvildanov Jun 17, 2025
79db257
Added DatabaseConfig test coverage
vladvildanov Jun 17, 2025
8790db1
Renamed DatabaseSelector into FailoverStrategy
vladvildanov Jun 18, 2025
b3ad8da
Added CommandExecutor
vladvildanov Jun 18, 2025
3a1dc9c
Updated healthcheck to close circuit on success
vladvildanov Jun 18, 2025
9bb9235
Added thread-safeness
vladvildanov Jun 19, 2025
3218e36
Added missing thread-safeness
vladvildanov Jun 19, 2025
4cdb6f4
Added missing thread-safenes for dispatcher
vladvildanov Jun 19, 2025
6914467
Refactored client to keep databases in WeightedList
vladvildanov Jun 19, 2025
5b94757
Added database CRUD operations
vladvildanov Jun 26, 2025
daba501
Added on-fly configuration
vladvildanov Jun 26, 2025
061e518
Added background health checks
vladvildanov Jun 27, 2025
a562774
Added background healthcheck + half-open event
vladvildanov Jul 2, 2025
3ab1367
Refactored background scheduling
vladvildanov Jul 3, 2025
3a55dcd
Merge branch 'feat/active-active' of github.com:redis/redis-py into v…
vladvildanov Jul 4, 2025
badef0e
Refactored healthchecks
vladvildanov Jul 7, 2025
fcc6035
Removed code repetitions, fixed weight assignment, added loops enhanc…
vladvildanov Jul 15, 2025
d5dc65c
Refactored configuration
vladvildanov Jul 17, 2025
7086822
Refactored failure detector
vladvildanov Jul 18, 2025
2561d6f
Refactored retry logic
vladvildanov Jul 18, 2025
a0af5b3
Added scenario tests
vladvildanov Jul 24, 2025
aaed8d7
Added pybreaker optional dependency
vladvildanov Jul 24, 2025
0551618
Added pybreaker to dev dependencies
vladvildanov Jul 24, 2025
1d288e6
Rename tests directory
vladvildanov Jul 24, 2025
6cdca81
Remove redundant checks
vladvildanov Jul 28, 2025
66ba193
Handle retries if default is not set
vladvildanov Aug 11, 2025
7c0c26a
Removed all Sentinel related
vladvildanov Aug 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions redis/data_structure.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import threading
from typing import List, Any, TypeVar, Generic, Union

from redis.typing import Number

T = TypeVar('T')

class WeightedList(Generic[T]):
"""
Thread-safe weighted list.
"""
def __init__(self):
self._items: List[tuple[Any, Union[int, float]]] = []
self._items: List[tuple[Any, Number]] = []
self._lock = threading.RLock()

def add(self, item: Any, weight: float) -> None:
Expand All @@ -18,35 +20,35 @@ def add(self, item: Any, weight: float) -> None:
left, right = 0, len(self._items)
while left < right:
mid = (left + right) // 2
if self._items[mid][0] < weight:
if self._items[mid][1] < weight:
right = mid
else:
left = mid + 1

self._items.insert(left, (weight, item))
self._items.insert(left, (item, weight))

def remove(self, item):
"""Remove first occurrence of item"""
with self._lock:
for i, (weight, stored_item) in enumerate(self._items):
for i, (stored_item, weight) in enumerate(self._items):
if stored_item == item:
self._items.pop(i)
return weight
raise ValueError("Item not found")

def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple[Any, Union[int, float]]]:
def get_by_weight_range(self, min_weight: float, max_weight: float) -> List[tuple[Any, Number]]:
"""Get all items within weight range"""
with self._lock:
result = []
for weight, item in self._items:
for item, weight in self._items:
if min_weight <= weight <= max_weight:
result.append((item, weight))
return result

def get_top_n(self, n: int) -> List[tuple[Any, Union[int, float]]]:
def get_top_n(self, n: int) -> List[tuple[Any, Number]]:
"""Get top N the highest weighted items"""
with self._lock:
return [(item, weight) for weight, item in self._items[:n]]
return [(item, weight) for item, weight in self._items[:n]]

def update_weight(self, item, new_weight: float):
with self._lock:
Expand All @@ -60,14 +62,14 @@ def __iter__(self):
with self._lock:
items_copy = self._items.copy() # Create snapshot as lock released after each 'yield'

for weight, item in items_copy:
for item, weight in items_copy:
yield item, weight

def __len__(self):
with self._lock:
return len(self._items)

def __getitem__(self, index) -> tuple[Any, Union[int, float]]:
def __getitem__(self, index) -> tuple[Any, Number]:
with self._lock:
weight, item = self._items[index]
item, weight = self._items[index]
return item, weight
16 changes: 6 additions & 10 deletions redis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,13 @@ async def dispatch_async(self, event: object):

def register_listeners(self, event_listeners: Dict[Type[object], List[EventListenerInterface]]):
with self._lock:
for event in event_listeners:
if event in self._event_listeners_mapping:
self._event_listeners_mapping[event] = list(set(self._event_listeners_mapping[event] + event_listeners[event]))
for event_type in event_listeners:
if event_type in self._event_listeners_mapping:
self._event_listeners_mapping[event_type] = list(
set(self._event_listeners_mapping[event_type] + event_listeners[event_type])
)
else:
self._event_listeners_mapping[event] = event_listeners[event]
self._event_listeners_mapping[event_type] = event_listeners[event_type]


class AfterConnectionReleasedEvent:
Expand Down Expand Up @@ -257,11 +259,9 @@ def __init__(
self,
command: tuple,
exception: Exception,
client,
):
self._command = command
self._exception = exception
self._client = client

@property
def command(self) -> tuple:
Expand All @@ -271,10 +271,6 @@ def command(self) -> tuple:
def exception(self) -> Exception:
return self._exception

@property
def client(self):
return self._client

class ReAuthConnectionListener(EventListenerInterface):
"""
Listener that performs re-authentication of given connection.
Expand Down
42 changes: 25 additions & 17 deletions redis/multidb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,27 @@ class MultiDBClient(RedisModuleCommands, CoreCommands, SentinelCommands):
"""
def __init__(self, config: MultiDbConfig):
self._databases = config.databases()
self._health_checks = config.health_checks
self._health_checks = config.default_health_checks() if config.health_checks is None else config.health_checks
self._health_check_interval = config.health_check_interval
self._failure_detectors = config.failure_detectors
self._failover_strategy = config.failover_strategy
self._failure_detectors = config.default_failure_detectors() \
if config.failure_detectors is None else config.failure_detectors
self._failover_strategy = config.default_failover_strategy() \
if config.failover_strategy is None else config.failover_strategy
self._failover_strategy.set_databases(self._databases)
self._auto_fallback_interval = config.auto_fallback_interval
self._event_dispatcher = config.event_dispatcher
self._command_executor = DefaultCommandExecutor(
failure_detectors=self._failure_detectors,
databases=self._databases,
command_retry=config.command_retry,
failover_strategy=self._failover_strategy,
event_dispatcher=self._event_dispatcher,
auto_fallback_interval=self._auto_fallback_interval,
)

for fd in self._failure_detectors:
fd.set_command_executor(command_executor=self._command_executor)

self._initialized = False
self._hc_lock = threading.RLock()
self._bg_scheduler = BackgroundScheduler()
Expand All @@ -52,23 +59,23 @@ def _initialize(self):
self._check_databases_health,
)

is_active_db = False
is_active_db_found = False

for database, weight in self._databases:
# Set on state changed callback for each circuit.
database.circuit.on_state_changed(self._on_circuit_state_change_callback)

# Set states according to a weights and circuit state
if database.circuit.state == CBState.CLOSED and not is_active_db:
if database.circuit.state == CBState.CLOSED and not is_active_db_found:
database.state = DBState.ACTIVE
self._command_executor.active_database = database
is_active_db = True
elif database.circuit.state == CBState.CLOSED and is_active_db:
is_active_db_found = True
elif database.circuit.state == CBState.CLOSED and is_active_db_found:
database.state = DBState.PASSIVE
else:
database.state = DBState.DISCONNECTED

if not is_active_db:
if not is_active_db_found:
raise NoValidDatabaseException('Initial connection failed - no active database found')

self._initialized = True
Expand All @@ -88,6 +95,7 @@ def set_active_database(self, database: AbstractDatabase) -> None:
for existing_db, _ in self._databases:
if existing_db == database:
exists = True
break

if not exists:
raise ValueError('Given database is not a member of database list')
Expand Down Expand Up @@ -115,11 +123,13 @@ def add_database(self, database: AbstractDatabase):

highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
self._databases.add(database, database.weight)
self._change_active_database(database, highest_weighted_db)

if database.weight > highest_weight and database.circuit.state == CBState.CLOSED:
database.state = DBState.ACTIVE
self._command_executor.active_database = database
highest_weighted_db.state = DBState.PASSIVE
def _change_active_database(self, new_database: AbstractDatabase, highest_weight_database: AbstractDatabase):
if new_database.weight > highest_weight_database.weight and new_database.circuit.state == CBState.CLOSED:
new_database.state = DBState.ACTIVE
self._command_executor.active_database = new_database
highest_weight_database.state = DBState.PASSIVE

def remove_database(self, database: Database):
"""
Expand All @@ -141,17 +151,15 @@ def update_database_weight(self, database: AbstractDatabase, weight: float):
for existing_db, _ in self._databases:
if existing_db == database:
exists = True
break

if not exists:
raise ValueError('Given database is not a member of database list')

highest_weighted_db, highest_weight = self._databases.get_top_n(1)[0]
self._databases.update_weight(database, weight)

if weight > highest_weight and database.circuit.state == CBState.CLOSED:
database.state = DBState.ACTIVE
self._command_executor.active_database = database
highest_weighted_db.state = DBState.PASSIVE
database.weight = weight
self._change_active_database(database, highest_weighted_db)

def add_failure_detector(self, failure_detector: FailureDetector):
"""
Expand Down
45 changes: 33 additions & 12 deletions redis/multidb/command_executor.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import socket
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import List, Union, Optional

from redis.exceptions import ConnectionError, TimeoutError
from redis.backoff import NoBackoff
from redis.event import EventDispatcherInterface, OnCommandFailEvent
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.multidb.database import Database, AbstractDatabase, Databases
from redis.multidb.circuit import State as CBState
from redis.multidb.event import RegisterCommandFailure
from redis.multidb.failover import FailoverStrategy
from redis.multidb.failure_detector import FailureDetector
from redis.retry import Retry


class CommandExecutor(ABC):
Expand Down Expand Up @@ -62,6 +62,12 @@ def auto_fallback_interval(self, auto_fallback_interval: float) -> None:
"""Sets auto-fallback interval."""
pass

@property
@abstractmethod
def command_retry(self) -> Retry:
"""Returns command retry object."""
pass

@abstractmethod
def execute_command(self, *args, **options):
"""Executes a command and returns the result."""
Expand All @@ -74,6 +80,7 @@ def __init__(
self,
failure_detectors: List[FailureDetector],
databases: Databases,
command_retry: Retry,
failover_strategy: FailoverStrategy,
event_dispatcher: EventDispatcherInterface,
auto_fallback_interval: float = DEFAULT_AUTO_FALLBACK_INTERVAL,
Expand All @@ -88,6 +95,7 @@ def __init__(
"""
self._failure_detectors = failure_detectors
self._databases = databases
self._command_retry = command_retry
self._failover_strategy = failover_strategy
self._event_dispatcher = event_dispatcher
self._auto_fallback_interval = auto_fallback_interval
Expand All @@ -107,6 +115,10 @@ def add_failure_detector(self, failure_detector: FailureDetector) -> None:
def databases(self) -> Databases:
return self._databases

@property
def command_retry(self) -> Retry:
return self._command_retry

@property
def active_database(self) -> Optional[AbstractDatabase]:
return self._active_database
Expand All @@ -128,6 +140,24 @@ def auto_fallback_interval(self, auto_fallback_interval: int) -> None:
self._auto_fallback_interval = auto_fallback_interval

def execute_command(self, *args, **options):
self._check_active_database()

return self._command_retry.call_with_retry(
lambda: self._execute_command(*args, **options),
lambda error: self._on_command_fail(error, *args),
)

def _execute_command(self, *args, **options):
self._check_active_database()
return self._active_database.client.execute_command(*args, **options)

def _on_command_fail(self, error, *args):
self._event_dispatcher.dispatch(OnCommandFailEvent(args, error))

def _check_active_database(self):
"""
Checks if active database need to be updated.
"""
if (
self._active_database is None
or self._active_database.circuit.state != CBState.CLOSED
Expand All @@ -139,15 +169,6 @@ def execute_command(self, *args, **options):
self._active_database = self._failover_strategy.database
self._schedule_next_fallback()

try:
return self._active_database.client.execute_command(*args, **options)
except (ConnectionError, TimeoutError, socket.timeout) as e:
# Register command failure
self._event_dispatcher.dispatch(OnCommandFailEvent(args, e, self.active_database.client))

# Retry until failure detector will trigger opening of circuit
return self.execute_command(*args, **options)

def _schedule_next_fallback(self) -> None:
if self._auto_fallback_interval == DEFAULT_AUTO_FALLBACK_INTERVAL:
return
Expand All @@ -158,7 +179,7 @@ def _setup_event_dispatcher(self):
"""
Registers command failure event listener.
"""
event_listener = RegisterCommandFailure(self._failure_detectors, self._databases)
event_listener = RegisterCommandFailure(self._failure_detectors)
self._event_dispatcher.register_listeners({
OnCommandFailEvent: [event_listener],
})
Loading