From 5924368a588c69cc748857becb60ea17e7b0ae98 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 20 Nov 2025 18:28:45 +0200 Subject: [PATCH 1/3] Adding SMIGRATED handling --- redis/_parsers/base.py | 8 +- redis/client.py | 15 + redis/cluster.py | 134 +++- redis/connection.py | 249 ++++++-- redis/maint_notifications.py | 86 ++- .../proxy_server_helpers.py | 90 ++- ...st_cluster_maint_notifications_handling.py | 602 +++++++++++++++++- .../test_maint_notifications.py | 35 +- 8 files changed, 1067 insertions(+), 152 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 2da7a94b3d..c5d4678022 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -14,6 +14,7 @@ OSSNodeMigratedNotification, OSSNodeMigratingNotification, ) +from redis.utils import safe_str if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout @@ -194,8 +195,9 @@ def parse_oss_maintenance_completed_msg(response): # Expected message format is: # SMIGRATED id = response[1] - node_address = response[2] + node_address = safe_str(response[2]) slots = response[3] + return OSSNodeMigratedNotification(id, node_address, slots) @staticmethod @@ -225,9 +227,7 @@ def parse_moving_msg(response): if response[3] is None: host, port = None, None else: - value = response[3] - if isinstance(value, bytes): - value = value.decode() + value = safe_str(response[3]) host, port = value.split(":") port = int(port) if port is not None else None diff --git a/redis/client.py b/redis/client.py index e2712fc3f8..354ae3a68a 100755 --- a/redis/client.py +++ b/redis/client.py @@ -58,6 +58,7 @@ from redis.lock import Lock from redis.maint_notifications import ( MaintNotificationsConfig, + OSSMaintNotificationsHandler, ) from redis.retry import Retry from redis.utils import ( @@ -250,6 +251,9 @@ def __init__( cache_config: Optional[CacheConfig] = None, event_dispatcher: Optional[EventDispatcher] = None, maint_notifications_config: Optional[MaintNotificationsConfig] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, ) -> None: """ Initialize a new Redis client. @@ -288,6 +292,11 @@ def __init__( will be enabled by default (logic is included in the connection pool initialization). Argument is ignored when connection_pool is provided. + oss_cluster_maint_notifications_handler: + handler for OSS cluster notifications - see + `redis.maint_notifications.OSSMaintNotificationsHandler` for details. + Only supported with RESP3 + Argument is ignored when connection_pool is provided. 
""" if event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -380,6 +389,12 @@ def __init__( "maint_notifications_config": maint_notifications_config, } ) + if oss_cluster_maint_notifications_handler: + kwargs.update( + { + "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler, + } + ) connection_pool = ConnectionPool(**kwargs) self._event_dispatcher.dispatch( AfterPooledConnectionsInstantiationEvent( diff --git a/redis/cluster.py b/redis/cluster.py index f06000563a..06c1e37f38 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -52,7 +52,10 @@ WatchError, ) from redis.lock import Lock -from redis.maint_notifications import MaintNotificationsConfig +from redis.maint_notifications import ( + MaintNotificationsConfig, + OSSMaintNotificationsHandler, +) from redis.retry import Retry from redis.utils import ( deprecated_args, @@ -214,6 +217,62 @@ def cleanup_kwargs(**kwargs): return connection_kwargs +class MaintNotificationsAbstractRedisCluster: + """ + Abstract class for handling maintenance notifications logic. + This class is expected to be used as base class together with RedisCluster. + + This class is intended to be used with multiple inheritance! + + All logic related to maintenance notifications is encapsulated in this class. + """ + + def __init__( + self, + maint_notifications_config: Optional[MaintNotificationsConfig], + **kwargs, + ): + # Initialize maintenance notifications + is_protocol_supported = kwargs.get("protocol") in [3, "3"] + if maint_notifications_config is None and is_protocol_supported: + maint_notifications_config = MaintNotificationsConfig() + + self.maint_notifications_config = maint_notifications_config + + if maint_notifications_config and maint_notifications_config.enabled: + if not is_protocol_supported: + raise RedisError( + "Maintenance notifications handlers on connection are only supported with RESP version 3" + ) + self._oss_cluster_maint_notifications_handler = ( + OSSMaintNotificationsHandler(self, maint_notifications_config) + ) + # Update connection kwargs for all future nodes connections + self._update_connection_kwargs_for_maint_notifications( + self._oss_cluster_maint_notifications_handler + ) + # Update existing nodes connections - they are created as part of the RedsiCluster constructor + for node in self.get_nodes(): + node.redis_connection.connection_pool.update_maint_notifications_config( + self.maint_notifications_config, + oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler, + ) + else: + self._oss_cluster_maint_notifications_handler = None + + def _update_connection_kwargs_for_maint_notifications( + self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler + ): + """ + Update the connection kwargs for all future connections. + """ + self.nodes_manager.connection_kwargs.update( + { + "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler, + } + ) + + class AbstractRedisCluster: RedisClusterRequestTTL = 16 @@ -461,7 +520,9 @@ def replace_default_node(self, target_node: "ClusterNode" = None) -> None: self.nodes_manager.default_node = random.choice(replicas) -class RedisCluster(AbstractRedisCluster, RedisClusterCommands): +class RedisCluster( + AbstractRedisCluster, MaintNotificationsAbstractRedisCluster, RedisClusterCommands +): @classmethod def from_url(cls, url, **kwargs): """ @@ -612,8 +673,7 @@ def __init__( `redis.maint_notifications.MaintNotificationsConfig` for details. Only supported with RESP3. 
If not provided and protocol is RESP3, the maintenance notifications - will be enabled by default (logic is included in the NodesManager - initialization). + will be enabled by default. :**kwargs: Extra arguments that will be sent into Redis instance when created (See Official redis-py doc for supported kwargs - the only limitation @@ -698,6 +758,13 @@ def __init__( if (cache_config or cache) and protocol not in [3, "3"]: raise RedisError("Client caching is only supported with RESP version 3") + if maint_notifications_config and protocol not in [3, "3"]: + raise RedisError( + "Maintenance notifications are only supported with RESP version 3" + ) + if protocol in [3, "3"] and maint_notifications_config is None: + maint_notifications_config = MaintNotificationsConfig() + self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.node_flags = self.__class__.NODE_FLAGS.copy() self.read_from_replicas = read_from_replicas @@ -709,6 +776,7 @@ def __init__( else: self._event_dispatcher = event_dispatcher self.startup_nodes = startup_nodes + self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, @@ -763,6 +831,10 @@ def __init__( self._aggregate_nodes = None self._lock = threading.RLock() + MaintNotificationsAbstractRedisCluster.__init__( + self, maint_notifications_config, **kwargs + ) + def __enter__(self): return self @@ -1632,9 +1704,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, cache_factory: Optional[CacheFactoryInterface] = None, event_dispatcher: Optional[EventDispatcher] = None, - maint_notifications_config: Optional[ - MaintNotificationsConfig - ] = MaintNotificationsConfig(), + maint_notifications_config: Optional[MaintNotificationsConfig] = None, **kwargs, ): self.nodes_cache: Dict[str, Redis] = {} @@ -1879,11 +1949,29 @@ def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): return target_node - def initialize(self): + def initialize( + self, + additional_startup_nodes_info: List[Tuple[str, int]] = [], + disconect_startup_nodes_pools: bool = True, + ): """ Initializes the nodes cache, slots cache and redis connections. :startup_nodes: Responsible for discovering other nodes in the cluster + :disconect_startup_nodes_pools: + Whether to disconnect the connection pool of the startup nodes + after the initialization is complete. This is useful when the + startup nodes are not part of the cluster and we want to avoid + keeping the connection open. + :additional_startup_nodes_info: + Additional nodes to add temporarily to the startup nodes. + The additional nodes will be used just in the process of extraction of the slots + and nodes information from the cluster. + This is useful when we want to add new nodes to the cluster + and initialize the client + with them. + The format of the list is a list of tuples, where each tuple contains + the host and port of the node. 
""" self.reset() tmp_nodes_cache = {} @@ -1893,9 +1981,25 @@ def initialize(self): fully_covered = False kwargs = self.connection_kwargs exception = None + + # Create cache if it's not provided and cache config is set + # should be done before initializing the first connection + # so that it will be applied to all connections + if self._cache is None and self._cache_config is not None: + if self._cache_factory is None: + self._cache = CacheFactory(self._cache_config).get_cache() + else: + self._cache = self._cache_factory.get_cache() + + additional_startup_nodes = [ + ClusterNode(host, port) for host, port in additional_startup_nodes_info + ] # Convert to tuple to prevent RuntimeError if self.startup_nodes # is modified during iteration - for startup_node in tuple(self.startup_nodes.values()): + for startup_node in ( + *self.startup_nodes.values(), + *additional_startup_nodes, + ): try: if startup_node.redis_connection: r = startup_node.redis_connection @@ -1911,7 +2015,11 @@ def initialize(self): # Make sure cluster mode is enabled on this node try: cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) - r.connection_pool.disconnect() + if disconect_startup_nodes_pools: + # Disconnect the connection pool to avoid keeping the connection open + # For some cases we might not want to disconnect current pool and + # lose in flight commands responses + r.connection_pool.disconnect() except ResponseError: raise RedisClusterException( "Cluster mode is not enabled on this node" @@ -1992,12 +2100,6 @@ def initialize(self): f"one reachable node: {str(exception)}" ) from exception - if self._cache is None and self._cache_config is not None: - if self._cache_factory is None: - self._cache = CacheFactory(self._cache_config).get_cache() - else: - self._cache = self._cache_factory.get_cache() - # Create Redis connections to all nodes self.create_redis_connections(list(tmp_nodes_cache.values())) diff --git a/redis/connection.py b/redis/connection.py index 0a87777ac3..726c776c46 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -52,6 +52,7 @@ MaintNotificationsConfig, MaintNotificationsConnectionHandler, MaintNotificationsPoolHandler, + OSSMaintNotificationsHandler, ) from .retry import Retry from .utils import ( @@ -285,6 +286,9 @@ def __init__( orig_host_address: Optional[str] = None, orig_socket_timeout: Optional[float] = None, orig_socket_connect_timeout: Optional[float] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None, ): """ @@ -298,6 +302,7 @@ def __init__( orig_host_address (Optional[str]): The original host address of the connection. orig_socket_timeout (Optional[float]): The original socket timeout of the connection. orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection. + oss_cluster_maint_notifications_handler (Optional[OSSMaintNotificationsHandler]): The OSS cluster handler for maintenance notifications. parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications. If not provided, the parser from the connection is used. This is useful when the parser is created after this object. 
@@ -310,6 +315,7 @@ def __init__( orig_host_address, orig_socket_timeout, orig_socket_connect_timeout, + oss_cluster_maint_notifications_handler, parser, ) @@ -386,6 +392,9 @@ def _configure_maintenance_notifications( orig_host_address=None, orig_socket_timeout=None, orig_socket_connect_timeout=None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None, ): """ @@ -400,6 +409,7 @@ def _configure_maintenance_notifications( ): self._maint_notifications_pool_handler = None self._maint_notifications_connection_handler = None + self._oss_cluster_maint_notifications_handler = None return if not parser: @@ -427,11 +437,30 @@ def _configure_maintenance_notifications( else: self._maint_notifications_pool_handler = None + if oss_cluster_maint_notifications_handler: + # Extract a reference to a new handler that copies all properties + # of the original one and has a different connection reference + # This is needed because when we attach the handler to the parser + # we need to make sure that the handler has a reference to the + # connection that the parser is attached to. + self._oss_cluster_maint_notifications_handler = ( + oss_cluster_maint_notifications_handler.get_handler_for_connection() + ) + self._oss_cluster_maint_notifications_handler.set_connection(self) + else: + self._oss_cluster_maint_notifications_handler = None + self._maint_notifications_connection_handler = ( MaintNotificationsConnectionHandler(self, self.maint_notifications_config) ) - # Set up pool handler if available + # Set up OSS cluster handler to parser if available + if self._oss_cluster_maint_notifications_handler: + parser.set_oss_cluster_maint_push_handler( + self._oss_cluster_maint_notifications_handler.handle_notification + ) + + # Set up pool handler to parser if available if self._maint_notifications_pool_handler: parser.set_node_moving_push_handler( self._maint_notifications_pool_handler.handle_notification @@ -486,6 +515,41 @@ def set_maint_notifications_pool_handler_for_connection( maint_notifications_pool_handler.config ) + def set_maint_notifications_cluster_handler_for_connection( + self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler + ): + # Deep copy the cluster handler to avoid sharing the same handler + # between multiple connections, because otherwise each connection will override + # the connection reference and the handler will only hold a reference + # to the last connection that was set. 
+ maint_notifications_cluster_handler_copy = ( + oss_cluster_maint_notifications_handler.get_handler_for_connection() + ) + + maint_notifications_cluster_handler_copy.set_connection(self) + self._get_parser().set_oss_cluster_maint_push_handler( + maint_notifications_cluster_handler_copy.handle_notification + ) + + self._oss_cluster_maint_notifications_handler = ( + maint_notifications_cluster_handler_copy + ) + + # Update maintenance notification connection handler if it doesn't exist + if not self._maint_notifications_connection_handler: + self._maint_notifications_connection_handler = ( + MaintNotificationsConnectionHandler( + self, oss_cluster_maint_notifications_handler.config + ) + ) + self._get_parser().set_maintenance_push_handler( + self._maint_notifications_connection_handler.handle_notification + ) + else: + self._maint_notifications_connection_handler.config = ( + oss_cluster_maint_notifications_handler.config + ) + def activate_maint_notifications_handling_if_enabled(self, check_health=True): # Send maintenance notifications handshake if RESP3 is active # and maintenance notifications are enabled @@ -688,6 +752,9 @@ def __init__( orig_host_address: Optional[str] = None, orig_socket_timeout: Optional[float] = None, orig_socket_connect_timeout: Optional[float] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, ): """ Initialize a new Connection. @@ -782,6 +849,7 @@ def __init__( orig_host_address, orig_socket_timeout, orig_socket_connect_timeout, + oss_cluster_maint_notifications_handler, self._parser, ) @@ -1352,6 +1420,7 @@ def __init__( self._conn.host, self._conn.socket_timeout, self._conn.socket_connect_timeout, + self._conn._oss_cluster_maint_notifications_handler, self._conn._get_parser(), ) @@ -1375,6 +1444,14 @@ def set_maint_notifications_pool_handler_for_connection( maint_notifications_pool_handler ) + def set_maint_notifications_cluster_handler_for_connection( + self, oss_cluster_maint_notifications_handler + ): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.set_maint_notifications_cluster_handler_for_connection( + oss_cluster_maint_notifications_handler + ) + def get_protocol(self): return self._conn.get_protocol() @@ -1564,6 +1641,19 @@ def socket_connect_timeout(self) -> Optional[Union[float, int]]: def socket_connect_timeout(self, value: Optional[Union[float, int]]): self._conn.socket_connect_timeout = value + @property + def _maint_notifications_connection_handler( + self, + ) -> Optional[MaintNotificationsConnectionHandler]: + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn._maint_notifications_connection_handler + + @_maint_notifications_connection_handler.setter + def _maint_notifications_connection_handler( + self, value: Optional[MaintNotificationsAbstractConnection] + ): + self._conn._maint_notifications_connection_handler = value + def _get_socket(self) -> Optional[socket.socket]: if isinstance(self._conn, MaintNotificationsAbstractConnection): return self._conn._get_socket() @@ -2031,6 +2121,9 @@ class MaintNotificationsAbstractConnectionPool: def __init__( self, maint_notifications_config: Optional[MaintNotificationsConfig] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, **kwargs, ): # Initialize maintenance notifications @@ -2043,16 +2136,26 @@ def __init__( raise RedisError( "Maintenance notifications handlers on connection are only supported with RESP version 3" ) + if 
oss_cluster_maint_notifications_handler: + self._oss_cluster_maint_notifications_handler = ( + oss_cluster_maint_notifications_handler + ) + self._update_connection_kwargs_for_maint_notifications( + oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler + ) + self._maint_notifications_pool_handler = None + else: + self._oss_cluster_maint_notifications_handler = None + self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( + self, maint_notifications_config + ) - self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( - self, maint_notifications_config - ) - - self._update_connection_kwargs_for_maint_notifications( - self._maint_notifications_pool_handler - ) + self._update_connection_kwargs_for_maint_notifications( + maint_notifications_pool_handler=self._maint_notifications_pool_handler + ) else: self._maint_notifications_pool_handler = None + self._oss_cluster_maint_notifications_handler = None @property @abstractmethod @@ -2085,16 +2188,25 @@ def maint_notifications_enabled(self): The maintenance notifications config is stored in the pool handler. If the pool handler is not set, the maintenance notifications are not enabled. """ - maint_notifications_config = ( - self._maint_notifications_pool_handler.config - if self._maint_notifications_pool_handler - else None - ) + if self._oss_cluster_maint_notifications_handler: + maint_notifications_config = ( + self._oss_cluster_maint_notifications_handler.config + ) + else: + maint_notifications_config = ( + self._maint_notifications_pool_handler.config + if self._maint_notifications_pool_handler + else None + ) return maint_notifications_config and maint_notifications_config.enabled def update_maint_notifications_config( - self, maint_notifications_config: MaintNotificationsConfig + self, + maint_notifications_config: MaintNotificationsConfig, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, ): """ Updates the maintenance notifications configuration. 
@@ -2110,37 +2222,59 @@ def update_maint_notifications_config( raise ValueError( "Cannot disable maintenance notifications after enabling them" ) - # first update pool settings - if not self._maint_notifications_pool_handler: - self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( - self, maint_notifications_config + if oss_cluster_maint_notifications_handler: + self._oss_cluster_maint_notifications_handler = ( + oss_cluster_maint_notifications_handler ) else: - self._maint_notifications_pool_handler.config = maint_notifications_config + # first update pool settings + if not self._maint_notifications_pool_handler: + self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( + self, maint_notifications_config + ) + else: + self._maint_notifications_pool_handler.config = ( + maint_notifications_config + ) # then update connection kwargs and existing connections self._update_connection_kwargs_for_maint_notifications( - self._maint_notifications_pool_handler + maint_notifications_pool_handler=self._maint_notifications_pool_handler, + oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler, ) self._update_maint_notifications_configs_for_connections( - self._maint_notifications_pool_handler + maint_notifications_pool_handler=self._maint_notifications_pool_handler, + oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler, ) def _update_connection_kwargs_for_maint_notifications( - self, maint_notifications_pool_handler: MaintNotificationsPoolHandler + self, + maint_notifications_pool_handler: Optional[ + MaintNotificationsPoolHandler + ] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, ): """ Update the connection kwargs for all future connections. """ if not self.maint_notifications_enabled(): return - - self.connection_kwargs.update( - { - "maint_notifications_pool_handler": maint_notifications_pool_handler, - "maint_notifications_config": maint_notifications_pool_handler.config, - } - ) + if maint_notifications_pool_handler: + self.connection_kwargs.update( + { + "maint_notifications_pool_handler": maint_notifications_pool_handler, + "maint_notifications_config": maint_notifications_pool_handler.config, + } + ) + if oss_cluster_maint_notifications_handler: + self.connection_kwargs.update( + { + "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler, + "maint_notifications_config": oss_cluster_maint_notifications_handler.config, + } + ) # Store original connection parameters for maintenance notifications. 
if self.connection_kwargs.get("orig_host_address", None) is None: @@ -2159,25 +2293,56 @@ def _update_connection_kwargs_for_maint_notifications( ) def _update_maint_notifications_configs_for_connections( - self, maint_notifications_pool_handler: MaintNotificationsPoolHandler + self, + maint_notifications_pool_handler: Optional[ + MaintNotificationsPoolHandler + ] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, ): """Update the maintenance notifications config for all connections in the pool.""" with self._get_pool_lock(): for conn in self._get_free_connections(): - conn.set_maint_notifications_pool_handler_for_connection( - maint_notifications_pool_handler - ) - conn.maint_notifications_config = ( - maint_notifications_pool_handler.config - ) + if oss_cluster_maint_notifications_handler: + ## set cluster handler for conn + conn.set_maint_notifications_cluster_handler_for_connection( + oss_cluster_maint_notifications_handler + ) + conn.maint_notifications_config = ( + oss_cluster_maint_notifications_handler.config + ) + elif maint_notifications_pool_handler: + conn.set_maint_notifications_pool_handler_for_connection( + maint_notifications_pool_handler + ) + conn.maint_notifications_config = ( + maint_notifications_pool_handler.config + ) + else: + raise ValueError( + "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set" + ) conn.disconnect() for conn in self._get_in_use_connections(): - conn.set_maint_notifications_pool_handler_for_connection( - maint_notifications_pool_handler - ) - conn.maint_notifications_config = ( - maint_notifications_pool_handler.config - ) + if oss_cluster_maint_notifications_handler: + conn.maint_notifications_config = ( + oss_cluster_maint_notifications_handler.config + ) + conn._configure_maintenance_notifications( + oss_cluster_maint_notifications_handler=oss_cluster_maint_notifications_handler + ) + elif maint_notifications_pool_handler: + conn.set_maint_notifications_pool_handler_for_connection( + maint_notifications_pool_handler + ) + conn.maint_notifications_config = ( + maint_notifications_pool_handler.config + ) + else: + raise ValueError( + "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set" + ) conn.mark_for_reconnect() def _should_update_connection( diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index dbb97d37b9..b161839493 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -10,7 +10,7 @@ from redis.typing import Number if TYPE_CHECKING: - from redis.cluster import NodesManager + from redis.cluster import MaintNotificationsAbstractRedisCluster class MaintenanceState(enum.Enum): @@ -463,8 +463,8 @@ class OSSNodeMigratedNotification(MaintenanceNotification): Args: id (int): Unique identifier for this notification - node_address (Optional[str]): Address of the node that has - completed migration - this is the destination node. 
+        node_address (str): Address of the node that has completed migration
+            in the format "host:port"
         slots (Optional[List[int]]): List of slots that have been migrated
     """
@@ -473,7 +473,7 @@ def __init__(
         self,
         id: int,
-        node_address: Optional[str] = None,
+        node_address: str,
         slots: Optional[List[int]] = None,
     ):
         super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
@@ -935,12 +935,30 @@ def handle_maintenance_completed_notification(self):


 class OSSMaintNotificationsHandler:
     def __init__(
-        self, nodes_manager: "NodesManager", config: MaintNotificationsConfig
+        self,
+        cluster_client: "MaintNotificationsAbstractRedisCluster",
+        config: MaintNotificationsConfig,
     ) -> None:
-        self.nodes_manager = nodes_manager
+        self.cluster_client = cluster_client
         self.config = config
         self._processed_notifications = set()
+        self._in_progress = set()
         self._lock = threading.RLock()
+        self.connection = None
+
+    def set_connection(self, connection: "MaintNotificationsAbstractConnection"):
+        self.connection = connection
+
+    def get_handler_for_connection(self):
+        # Copy all data that should be shared between connections
+        # but each connection should have its own handler instance
+        # since each connection can be in a different state
+        copy = OSSMaintNotificationsHandler(self.cluster_client, self.config)
+        copy._processed_notifications = self._processed_notifications
+        copy._in_progress = self._in_progress
+        copy._lock = self._lock
+        copy.connection = None
+        return copy

     def remove_expired_notifications(self):
         with self._lock:
@@ -958,4 +976,58 @@ def handle_oss_maintenance_completed_notification(
         self, notification: OSSNodeMigratedNotification
     ):
         self.remove_expired_notifications()
-        logging.info(f"Received OSS maintenance completed notification: {notification}")
+
+        with self._lock:
+            if (
+                notification in self._in_progress
+                or notification in self._processed_notifications
+            ):
+                # we are already handling this notification or it has already been processed
+                # we should skip in_progress notification since when we reinitialize the cluster
+                # we execute a CLUSTER SLOTS command that can use a different connection
+                # that also has the notification and we don't want to
+                # process the same notification twice
+                return
+
+            self._in_progress.add(notification)
+
+            # get the node to which the connection is connected
+            # before refreshing the cluster topology
+            current_node = self.cluster_client.nodes_manager.get_node(
+                host=self.connection.host, port=self.connection.port
+            )
+
+            # Updates the cluster slots cache with the new slots mapping
+            # This will also update the nodes cache with the new nodes mapping
+            new_node_host, new_node_port = notification.node_address.split(":")
+            self.cluster_client.nodes_manager.initialize(
+                disconect_startup_nodes_pools=False,
+                additional_startup_nodes_info=[(new_node_host, int(new_node_port))],
+            )
+
+            if (
+                current_node
+                not in self.cluster_client.nodes_manager.nodes_cache.values()
+            ):
+                # disconnect all free connections to the node
+                for conn in current_node.redis_connection.connection_pool._get_free_connections():
+                    conn.disconnect()
+                # mark for reconnect all in use connections to the node - this will force them to
+                # disconnect after they complete their current commands
+                for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
+                    conn.mark_for_reconnect()
+            else:
+                if self.config.is_relaxed_timeouts_enabled():
+                    # reset the timeouts for the node to which the connection
is connected + # TODO: add check if other maintenance ops are in progress for the same node - CAE-1038 + # and if so, don't reset the timeouts + for conn in ( + *current_node.redis_connection.connection_pool._get_in_use_connections(), + *current_node.redis_connection.connection_pool._get_free_connections(), + ): + conn.update_current_socket_timeout(relaxed_timeout=-1) + conn.maintenance_state = MaintenanceState.NONE + + # mark the notification as processed + self._processed_notifications.add(notification) + self._in_progress.remove(notification) diff --git a/tests/maint_notifications/proxy_server_helpers.py b/tests/maint_notifications/proxy_server_helpers.py index b3397beefa..40dfd270b9 100644 --- a/tests/maint_notifications/proxy_server_helpers.py +++ b/tests/maint_notifications/proxy_server_helpers.py @@ -1,11 +1,10 @@ import base64 +from dataclasses import dataclass import logging import re -from typing import Union +from typing import Optional, Union from redis.http.http_client import HttpClient, HttpError -# from urllib.request import Request, urlopen -# from urllib.error import URLError class RespTranslator: @@ -34,7 +33,7 @@ def cluster_slots_to_resp(resp: str) -> str: ) @staticmethod - def smigrating_to_resp(resp: str) -> str: + def oss_maint_notification_to_resp(resp: str) -> str: """Convert query to RESP format.""" return ( f">{len(resp.split())}\r\n" @@ -45,6 +44,14 @@ def smigrating_to_resp(resp: str) -> str: ) +@dataclass +class SlotsRange: + host: str + port: int + start_slot: int + end_slot: int + + class ProxyInterceptorHelper: """Helper class for intercepting socket calls and managing interceptor server.""" @@ -52,6 +59,7 @@ def __init__(self, server_url: str = "http://localhost:4000"): self.server_url = server_url self._resp_translator = RespTranslator() self.http_client = HttpClient() + self._interceptors = list() def cleanup_interceptors(self, *names: str): """ @@ -60,56 +68,67 @@ def cleanup_interceptors(self, *names: str): Args: names: Names of the interceptors to reset """ - for name in names: + if not names: + names = self._interceptors + for name in tuple(names): self._reset_interceptor(name) - def set_cluster_nodes(self, name: str, nodes: list[tuple[str, int]]) -> str: + def set_cluster_slots( + self, + name: str, + slots_ranges: list[SlotsRange], + ) -> str: """ - Set cluster nodes by intercepting CLUSTER SLOTS command. + Set cluster slots and nodes by intercepting CLUSTER SLOTS command. This method creates an interceptor that intercepts CLUSTER SLOTS commands - and returns a modified topology with the provided nodes. + and returns a modified topology with the provided data. 
Args: name: Name of the interceptor - nodes: List of (host, port) tuples representing the cluster nodes + slots_ranges: List of SlotsRange objects representing the cluster + nodes and slots coverage Returns: The interceptor name that was created Example: interceptor = ProxyInterceptorHelper(None, "http://localhost:4000") - interceptor_name = interceptor.set_cluster_nodes( + interceptor.set_cluster_slots( "test_topology", - [("127.0.0.1", 6379), ("127.0.0.1", 6380), ("127.0.0.1", 6381)] + [ + SlotsRange("127.0.0.1", 6379, 0, 5000), + SlotsRange("127.0.0.1", 6380, 5001, 10000), + SlotsRange("127.0.0.1", 6381, 10001, 16383), + ] ) """ # Build RESP response for CLUSTER SLOTS # Format: * for each range: *3 :start :end *3 $ : $ - resp_parts = [f"*{len(nodes)}"] - - # For simplicity, distribute slots evenly across nodes - total_slots = 16384 - slots_per_node = total_slots // len(nodes) - - for i, (host, port) in enumerate(nodes): - start_slot = i * slots_per_node - end_slot = ( - (i + 1) * slots_per_node - 1 if i < len(nodes) - 1 else total_slots - 1 - ) + resp_parts = [f"*{len(slots_ranges)}"] + for slots_range in slots_ranges: # Node info: *3 for (host, port, id) resp_parts.append("*3") - resp_parts.append(f":{start_slot}") - resp_parts.append(f":{end_slot}") - - # Node details: *3 for (host, port, id) - resp_parts.append("*3") - resp_parts.append(f"${len(host)}") - resp_parts.append(host) - resp_parts.append(f":{port}") - resp_parts.append("$13") - resp_parts.append(f"proxy-id-{port}") + # 1st elem --> start slot + resp_parts.append(f":{slots_range.start_slot}") + # 2nd elem --> end slot + resp_parts.append(f":{slots_range.end_slot}") + + # 3rd elem --> list with node details: *4 for (host, port, id, empty hash) + resp_parts.append("*4") + # 1st elem --> host + resp_parts.append(f"${len(slots_range.host)}") + resp_parts.append(f"{slots_range.host}") + # 2nd elem --> port + resp_parts.append(f":{slots_range.port}") + # 3rd elem --> node id + node_id = f"proxy-id-{slots_range.port}" + resp_parts.append(f"${len(node_id)}") + resp_parts.append(node_id) + # 4th elem --> empty hash + resp_parts.append("$0") + resp_parts.append("") response = "\r\n".join(resp_parts) + "\r\n" @@ -257,7 +276,10 @@ def _add_interceptor( proxy_response = self.http_client.post( url, json_body=payload, headers=headers ) - return proxy_response.json() + self._interceptors.append(name) + if isinstance(proxy_response, dict): + return proxy_response + return proxy_response.json() if proxy_response else {} except HttpError as e: raise RuntimeError(f"Failed to add interceptor: {e}") @@ -268,4 +290,4 @@ def _reset_interceptor(self, name: str): Args: name: Name of the interceptor to reset """ - self._add_interceptor(name, "", "") + self._add_interceptor(name, "no_match", "") diff --git a/tests/maint_notifications/test_cluster_maint_notifications_handling.py b/tests/maint_notifications/test_cluster_maint_notifications_handling.py index 8e8d53b62f..97e91d2a05 100644 --- a/tests/maint_notifications/test_cluster_maint_notifications_handling.py +++ b/tests/maint_notifications/test_cluster_maint_notifications_handling.py @@ -11,12 +11,15 @@ from tests.maint_notifications.proxy_server_helpers import ( ProxyInterceptorHelper, RespTranslator, + SlotsRange, ) NODE_PORT_1 = 15379 NODE_PORT_2 = 15380 NODE_PORT_3 = 15381 +NODE_PORT_NEW = 15382 + # Initial cluster node configuration for proxy-based tests PROXY_CLUSTER_NODES = [ ClusterNode("127.0.0.1", NODE_PORT_1), @@ -34,6 +37,7 @@ def _create_cluster_client( enable_cache=False, 
max_connections=10, maint_config=None, + protocol=3, ) -> RedisCluster: """Create a RedisCluster instance with mocked sockets.""" if maint_config is None and hasattr(self, "config") and self.config is not None: @@ -44,7 +48,7 @@ def _create_cluster_client( kwargs = {"cache_config": CacheConfig()} test_redis_client = RedisCluster( - protocol=3, + protocol=protocol, startup_nodes=PROXY_CLUSTER_NODES, maint_notifications_config=maint_config, connection_pool_class=pool_class, @@ -214,14 +218,37 @@ def test_config_with_cache_enabled(self): def test_none_config_default_behavior(self): """ - Test that when maint_notifications_config=None, the system works without errors. + Test that when maint_notifications_config=None, it will be initialized with default values. """ cluster = self._create_cluster_client(maint_config=None) try: # Verify cluster is created successfully assert cluster.nodes_manager is not None + # for protocol 3, maint_notifications_config should be initialized with default values + assert cluster.nodes_manager.maint_notifications_config is not None + assert cluster.nodes_manager.maint_notifications_config.enabled == "auto" + assert len(cluster.nodes_manager.nodes_cache) > 0 + # Verify we can execute commands without errors + cluster.set("test", "VAL") + res = cluster.get("test") + assert res == b"VAL" + finally: + cluster.close() + + def test_none_config_default_behavior_for_protocol_2(self): + """ + Test that when maint_notifications_config=None and protocol=2, + it will not be initialized. + """ + cluster = self._create_cluster_client(protocol=2) + + try: + # Verify cluster is created successfully + assert cluster.nodes_manager is not None + # for protocol 2, maint_notifications_config should not be created assert cluster.nodes_manager.maint_notifications_config is None + assert len(cluster.nodes_manager.nodes_cache) > 0 # Verify we can execute commands without errors cluster.set("test", "VAL") @@ -285,6 +312,95 @@ def test_config_with_pipeline_operations(self): cluster.close() +class TestClusterMaintNotificationsHandler(TestClusterMaintNotificationsBase): + """Test OSSMaintNotificationsHandler propagation with RedisCluster.""" + + def _validate_connection_handlers( + self, conn, cluster_client, config, is_cache_conn=False + ): + """Helper method to validate connection handlers are properly set.""" + # Test that the oss cluster handler function is correctly set + oss_cluster_parser_handler_set_for_con = ( + conn._parser.oss_cluster_maint_push_handler_func + ) + assert oss_cluster_parser_handler_set_for_con is not None + assert hasattr(oss_cluster_parser_handler_set_for_con, "__self__") + assert hasattr(oss_cluster_parser_handler_set_for_con, "__func__") + assert oss_cluster_parser_handler_set_for_con.__self__.connection is conn + assert ( + oss_cluster_parser_handler_set_for_con.__self__.cluster_client + is cluster_client + ) + assert ( + oss_cluster_parser_handler_set_for_con.__self__._lock + is cluster_client._oss_cluster_maint_notifications_handler._lock + ) + assert ( + oss_cluster_parser_handler_set_for_con.__self__._processed_notifications + is cluster_client._oss_cluster_maint_notifications_handler._processed_notifications + ) + assert ( + oss_cluster_parser_handler_set_for_con.__func__ + is cluster_client._oss_cluster_maint_notifications_handler.handle_notification.__func__ + ) + + # Test that the maintenance handler function is correctly set + parser_maint_handler_set_for_con = conn._parser.maintenance_push_handler_func + assert parser_maint_handler_set_for_con is not 
None + assert hasattr(parser_maint_handler_set_for_con, "__self__") + assert hasattr(parser_maint_handler_set_for_con, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance notification connection handler + assert ( + parser_maint_handler_set_for_con.__self__ + is conn._maint_notifications_connection_handler + ) + assert ( + parser_maint_handler_set_for_con.__func__ + is conn._maint_notifications_connection_handler.handle_notification.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maint_notifications_connection_handler.config is config + + def test_oss_maint_handler_propagation(self): + """Test that OSSMaintNotificationsHandler is propagated to all connections.""" + cluster = self._create_cluster_client() + # Verify all nodes have the handler + for node in cluster.nodes_manager.nodes_cache.values(): + assert node.redis_connection is not None + assert node.redis_connection.connection_pool is not None + for conn in ( + *node.redis_connection.connection_pool._get_in_use_connections(), + *node.redis_connection.connection_pool._get_free_connections(), + ): + assert conn._oss_cluster_maint_notifications_handler is not None + assert conn._oss_cluster_maint_notifications_handler.connection is conn + self._validate_connection_handlers( + conn, cluster, cluster.maint_notifications_config + ) + + def test_oss_maint_handler_propagation_cache_enabled(self): + """Test that OSSMaintNotificationsHandler is propagated to all connections.""" + cluster = self._create_cluster_client(enable_cache=True) + # Verify all nodes have the handler + for node in cluster.nodes_manager.nodes_cache.values(): + assert node.redis_connection is not None + assert node.redis_connection.connection_pool is not None + for conn in ( + *node.redis_connection.connection_pool._get_in_use_connections(), + *node.redis_connection.connection_pool._get_free_connections(), + ): + assert conn._conn._oss_cluster_maint_notifications_handler is not None + assert ( + conn._conn._oss_cluster_maint_notifications_handler.connection + is conn._conn + ) + self._validate_connection_handlers( + conn._conn, cluster, cluster.maint_notifications_config + ) + + class TestClusterMaintNotificationsHandlingBase(TestClusterMaintNotificationsBase): """Base class for maintenance notifications handling tests.""" @@ -317,6 +433,21 @@ class ConnectionStateExpectation: class TestClusterMaintNotificationsHandling(TestClusterMaintNotificationsHandlingBase): """Test maintenance notifications handling with RedisCluster.""" + def _warm_up_connection_pools( + self, cluster: RedisCluster, created_connections_count: int = 3 + ): + """Warm up connection pools by getting a connection from each pool.""" + for node in cluster.nodes_manager.nodes_cache.values(): + node_connections = [] + for _ in range(created_connections_count): + node_connections.append( + node.redis_connection.connection_pool.get_connection() + ) + for conn in node_connections: + node.redis_connection.connection_pool.release(conn) + + node_connections.clear() + def _get_expected_node_state( self, expectations_list: List[ConnectionStateExpectation], node_port: int ) -> Optional[ConnectionStateExpectation]: @@ -361,23 +492,26 @@ def _validate_connections_states( changed_connections_count += 1 assert changed_connections_count == expected_state.changed_connections_count - def test_receive_oss_maintenance_notification(self): - """Test receiving an OSS maintenance notification.""" - # get three connections from 
each node - for node in self.cluster.nodes_manager.nodes_cache.values(): - node_connections = [] - for _ in range(3): - node_connections.append( - node.redis_connection.connection_pool.get_connection() - ) - for conn in node_connections: - node.redis_connection.connection_pool.release(conn) + def _validate_removed_node_connections(self, node): + """Validate connections in a removed node.""" + assert node.redis_connection is not None + connection_pool = node.redis_connection.connection_pool + assert connection_pool is not None - node_connections.clear() + # validate all connections are disconnected or marked for reconnect + for conn in connection_pool._get_free_connections(): + assert conn._sock is None + for conn in connection_pool._get_in_use_connections(): + assert conn.should_reconnect() + + def test_receive_smigrating_notification(self): + """Test receiving an OSS maintenance notification.""" + # warm up connection pools + self._warm_up_connection_pools(self.cluster, created_connections_count=3) # send a notification to node 1 - notification = RespTranslator.smigrating_to_resp( - "SMIGRATING 12 TO 127.0.0.1:15380 <123,456,5000-7000>" + notification = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <123,456,5000-7000>" ) self.proxy_helper.send_notification(NODE_PORT_1, notification) @@ -420,24 +554,424 @@ def test_receive_oss_maintenance_notification(self): ], ) - def test_receive_maint_notification(self): - """Test receiving a maintenance notification.""" - self.cluster.set("test", "VAL") - pubsub = self.cluster.pubsub() - pubsub.subscribe("test") - test_msg = pubsub.get_message(ignore_subscribe_messages=True, timeout=10) - print(test_msg) + def test_receive_smigrating_with_disabled_relaxed_timeout(self): + """Test receiving an OSS maintenance notification with disabled relaxed timeout.""" + # Create config with disabled relaxed timeout + disabled_config = MaintNotificationsConfig( + enabled="auto", + relaxed_timeout=-1, # This means the relaxed timeout is Disabled + ) + cluster = self._create_cluster_client(maint_config=disabled_config) + + # warm up connection pools + self._warm_up_connection_pools(cluster, created_connections_count=3) + + # send a notification to node 1 + notification = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <123,456,5000-7000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, notification) + + # validate no timeout is relaxed on any connection + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, changed_connections_count=0 + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, changed_connections_count=0 + ), + ConnectionStateExpectation( + node_port=NODE_PORT_3, changed_connections_count=0 + ), + ], + ) + + def test_receive_smigrated_notification(self): + """Test receiving an OSS maintenance completed notification.""" + # create three connections in each node's connection pool + self._warm_up_connection_pools(self.cluster, created_connections_count=3) + + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", NODE_PORT_NEW, 0, 5460), + SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + # send a notification to node 1 + notification = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 12 127.0.0.1:15380 <123,456,5000-7000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, notification) + + # execute a command that will receive the notification + res = 
self.cluster.set("anyprefix:{3}:k", "VAL") + assert res is True + + # validate the cluster topology was updated + new_node = self.cluster.nodes_manager.get_node( + host="0.0.0.0", port=NODE_PORT_NEW + ) + assert new_node is not None + + def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): + """Test receiving an OSS maintenance notification on two nodes without node replacement.""" + # warm up connection pools - create several connections in each pool + self._warm_up_connection_pools(self.cluster, created_connections_count=3) + + node_1 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_1) + node_2 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_2) + + smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <123,2000-3000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, changed_connections_count=0 + ), + ], + ) + + smigrating_node_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 13 <8000-9000>" + ) + self.proxy_helper.send_notification(NODE_PORT_2, smigrating_node_2) + + # execute command with node 2 connection + self.cluster.set("anyprefix:{1}:k", "VAL") + + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 14 0.0.0.0:15381 <123,2000-3000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + + smigrated_node_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 15 0.0.0.0:15381 <8000-9000>" + ) + self.proxy_helper.send_notification(NODE_PORT_2, smigrated_node_2) + + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", NODE_PORT_1, 0, 122), + SlotsRange("0.0.0.0", NODE_PORT_3, 123, 123), + SlotsRange("0.0.0.0", NODE_PORT_1, 124, 2000), + SlotsRange("0.0.0.0", NODE_PORT_3, 2001, 3000), + SlotsRange("0.0.0.0", NODE_PORT_1, 3001, 5460), + SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + + # validate the cluster topology was updated + # validate old nodes are there + assert node_1 in self.cluster.nodes_manager.nodes_cache.values() + assert node_2 in self.cluster.nodes_manager.nodes_cache.values() + # validate changed slot is assigned to node 3 + assert self.cluster.nodes_manager.get_node_from_slot( + 123 + ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_3) + # validate the connections are in the correct state + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=0, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, + 
changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", NODE_PORT_1, 0, 122), + SlotsRange("0.0.0.0", NODE_PORT_3, 123, 123), + SlotsRange("0.0.0.0", NODE_PORT_1, 124, 2000), + SlotsRange("0.0.0.0", NODE_PORT_3, 2001, 3000), + SlotsRange("0.0.0.0", NODE_PORT_1, 3001, 5460), + SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 7000), + SlotsRange("0.0.0.0", NODE_PORT_3, 7001, 8000), + SlotsRange("0.0.0.0", NODE_PORT_2, 8001, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + # execute command with node 2 connection + self.cluster.set("anyprefix:{1}:k", "VAL") + + # validate old nodes are there + assert node_1 in self.cluster.nodes_manager.nodes_cache.values() + assert node_2 in self.cluster.nodes_manager.nodes_cache.values() + # validate slot changes are reflected + assert self.cluster.nodes_manager.get_node_from_slot( + 8000 + ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_3) + + # validate the connections are in the correct state + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=0, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, + changed_connections_count=0, + ), + ], + ) + + def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): + """Test receiving an OSS maintenance notification on two nodes with node replacement.""" + # warm up connection pools - create several connections in each pool + self._warm_up_connection_pools(self.cluster, created_connections_count=3) + + node_1 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_1) + node_2 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_2) + node_3 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_3) + + smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <0-5460>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, changed_connections_count=0 + ), + ], + ) - # Try to send a push notification to the clients of given server node - # Server node is defined by its port with the local test environment - # The message should be in the format: - # >3\r\n$7\r\nmessage\r\n$3\r\nfoo\r\n$4\r\neeee\r - notification = RespTranslator.smigrating_to_resp( - "TEST_NOTIFICATION 12182 127.0.0.1:15380" + smigrating_node_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 13 <5461-10922>" ) - self.proxy_helper.send_notification(pubsub.connection.port, notification) - res = self.proxy_helper.get_connections() - print(res) + self.proxy_helper.send_notification(NODE_PORT_2, smigrating_node_2) - test_msg = pubsub.get_message(timeout=1) - print(test_msg) + # execute command with node 2 connection + self.cluster.set("anyprefix:{1}:k", "VAL") + + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, 
+ ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 14 0.0.0.0:15382 <0-5460>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + + smigrated_node_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 15 0.0.0.0:15382 <5461-10922>" + ) + self.proxy_helper.send_notification(NODE_PORT_2, smigrated_node_2) + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", 15382, 0, 5460), + SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + + # validate node 1 is removed + assert node_1 not in self.cluster.nodes_manager.nodes_cache.values() + # validate node 2 is still there + assert node_2 in self.cluster.nodes_manager.nodes_cache.values() + # validate new node is added + new_node = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15382) + assert new_node is not None + assert new_node.redis_connection is not None + # validate a slot from the changed range is assigned to the new node + assert self.cluster.nodes_manager.get_node_from_slot( + 123 + ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15382) + + # validate the connections are in the correct state + self._validate_removed_node_connections(node_1) + + # validate the connections are in the correct state + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_2, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", 15382, 0, 5460), + SlotsRange("0.0.0.0", 15383, 5461, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + # execute command with node 2 connection + self.cluster.set("anyprefix:{1}:k", "VAL") + + # validate node 2 is removed + assert node_2 not in self.cluster.nodes_manager.nodes_cache.values() + # validate node 3 is still there + assert node_3 in self.cluster.nodes_manager.nodes_cache.values() + # validate new node is added + new_node = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15383) + assert new_node is not None + assert new_node.redis_connection is not None + # validate a slot from the changed range is assigned to the new node + assert self.cluster.nodes_manager.get_node_from_slot( + 8000 + ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15383) + + # validate the connections in removed node are in the correct state + self._validate_removed_node_connections(node_2) + + def test_smigrating_smigrated_on_the_same_node_two_slot_ranges( + self, + ): + """ + Test receiving an OSS maintenance notification on the same node twice. 
+ The focus here is to validate that the timeouts are not unrelaxed if a second + migration is in progress + """ + # warm up connection pools - create several connections in each pool + self._warm_up_connection_pools(self.cluster, created_connections_count=1) + + smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <1000-2000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + + smigrating_node_1_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 13 <3000-4000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1_2) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 14 0.0.0.0:15380 <1000-2000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + # this functionality is part of CAE-1038 and will be added later + # validate the timeout is still relaxed + # self._validate_connections_states( + # self.cluster, + # [ + # ConnectionStateExpectation( + # node_port=NODE_PORT_1, + # changed_connections_count=1, + # state=MaintenanceState.MAINTENANCE, + # relaxed_timeout=self.config.relaxed_timeout, + # ), + # ], + # ) + smigrated_node_1_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 15 0.0.0.0:15381 <3000-4000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1_2) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=0, + ), + ], + ) diff --git a/tests/maint_notifications/test_maint_notifications.py b/tests/maint_notifications/test_maint_notifications.py index 4ae5ae60a5..e4aec1f3a8 100644 --- a/tests/maint_notifications/test_maint_notifications.py +++ b/tests/maint_notifications/test_maint_notifications.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, call, patch, MagicMock import pytest +from redis.cluster import ClusterNode from redis.connection import ConnectionInterface, MaintNotificationsAbstractConnection from redis.maint_notifications import ( @@ -493,40 +494,44 @@ class TestOSSNodeMigratedNotification: def test_init_with_defaults(self): """Test OSSNodeMigratedNotification initialization with default values.""" with patch("time.monotonic", return_value=1000): - notification = OSSNodeMigratedNotification(id=1) + notification = OSSNodeMigratedNotification( + id=1, node_address="127.0.0.1:6380" + ) assert notification.id == 1 assert notification.ttl == OSSNodeMigratedNotification.DEFAULT_TTL assert notification.creation_time == 1000 - assert notification.node_address is None + assert notification.node_address == "127.0.0.1:6380" assert notification.slots is None def test_init_with_all_parameters(self): 
"""Test OSSNodeMigratedNotification initialization with all parameters.""" with patch("time.monotonic", return_value=1000): slots = [1, 2, 3, 4, 5] + node_address = "127.0.0.1:6380" notification = OSSNodeMigratedNotification( id=1, - node_address="127.0.0.1:6380", + node_address=node_address, slots=slots, ) assert notification.id == 1 assert notification.ttl == OSSNodeMigratedNotification.DEFAULT_TTL assert notification.creation_time == 1000 - assert notification.node_address == "127.0.0.1:6380" + assert notification.node_address == node_address assert notification.slots == slots def test_default_ttl(self): """Test that DEFAULT_TTL is used correctly.""" assert OSSNodeMigratedNotification.DEFAULT_TTL == 30 - notification = OSSNodeMigratedNotification(id=1) + notification = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") assert notification.ttl == 30 def test_repr(self): """Test OSSNodeMigratedNotification string representation.""" with patch("time.monotonic", return_value=1000): + node_address = "127.0.0.1:6380" notification = OSSNodeMigratedNotification( id=1, - node_address="127.0.0.1:6380", + node_address=node_address, slots=[1, 2, 3], ) @@ -555,13 +560,13 @@ def test_equality_same_id_and_type(self): def test_equality_different_id(self): """Test inequality for notifications with different id.""" - notification1 = OSSNodeMigratedNotification(id=1) - notification2 = OSSNodeMigratedNotification(id=2) + notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") + notification2 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6380") assert notification1 != notification2 def test_equality_different_type(self): """Test inequality for notifications of different types.""" - notification1 = OSSNodeMigratedNotification(id=1) + notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") notification2 = NodeMigratedNotification(id=1) assert notification1 != notification2 @@ -582,16 +587,16 @@ def test_hash_same_id_and_type(self): def test_hash_different_id(self): """Test hash for notifications with different id.""" - notification1 = OSSNodeMigratedNotification(id=1) - notification2 = OSSNodeMigratedNotification(id=2) + notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") + notification2 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6380") assert hash(notification1) != hash(notification2) def test_in_set(self): """Test that notifications can be used in sets.""" - notification1 = OSSNodeMigratedNotification(id=1) - notification2 = OSSNodeMigratedNotification(id=1) - notification3 = OSSNodeMigratedNotification(id=2) - notification4 = OSSNodeMigratedNotification(id=2) + notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") + notification2 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") + notification3 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6381") + notification4 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6381") notification_set = {notification1, notification2, notification3, notification4} assert ( From 76172f993affe9e9bb4e85a411ad36448c8b521d Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Wed, 26 Nov 2025 10:49:22 +0200 Subject: [PATCH 2/3] Applying Copilot's comments --- redis/cluster.py | 8 ++++---- redis/connection.py | 4 ++-- redis/maint_notifications.py | 2 +- tests/maint_notifications/proxy_server_helpers.py | 2 +- tests/maint_notifications/test_maint_notifications.py | 1 - 5 files changed, 8 
insertions(+), 9 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 06c1e37f38..ed26188c64 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -251,7 +251,7 @@ def __init__( self._update_connection_kwargs_for_maint_notifications( self._oss_cluster_maint_notifications_handler ) - # Update existing nodes connections - they are created as part of the RedsiCluster constructor + # Update existing nodes connections - they are created as part of the RedisCluster constructor for node in self.get_nodes(): node.redis_connection.connection_pool.update_maint_notifications_config( self.maint_notifications_config, @@ -1952,13 +1952,13 @@ def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): def initialize( self, additional_startup_nodes_info: List[Tuple[str, int]] = [], - disconect_startup_nodes_pools: bool = True, + disconnect_startup_nodes_pools: bool = True, ): """ Initializes the nodes cache, slots cache and redis connections. :startup_nodes: Responsible for discovering other nodes in the cluster - :disconect_startup_nodes_pools: + :disconnect_startup_nodes_pools: Whether to disconnect the connection pool of the startup nodes after the initialization is complete. This is useful when the startup nodes are not part of the cluster and we want to avoid @@ -2015,7 +2015,7 @@ def initialize( # Make sure cluster mode is enabled on this node try: cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) - if disconect_startup_nodes_pools: + if disconnect_startup_nodes_pools: # Disconnect the connection pool to avoid keeping the connection open # For some cases we might not want to disconnect current pool and # lose in flight commands responses diff --git a/redis/connection.py b/redis/connection.py index 726c776c46..1c16098098 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1650,7 +1650,7 @@ def _maint_notifications_connection_handler( @_maint_notifications_connection_handler.setter def _maint_notifications_connection_handler( - self, value: Optional[MaintNotificationsAbstractConnection] + self, value: Optional[MaintNotificationsConnectionHandler] ): self._conn._maint_notifications_connection_handler = value @@ -2305,7 +2305,7 @@ def _update_maint_notifications_configs_for_connections( with self._get_pool_lock(): for conn in self._get_free_connections(): if oss_cluster_maint_notifications_handler: - ## set cluster handler for conn + # set cluster handler for conn conn.set_maint_notifications_cluster_handler_for_connection( oss_cluster_maint_notifications_handler ) diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index b161839493..cea0da47bd 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -1001,7 +1001,7 @@ def handle_oss_maintenance_completed_notification( # This will also update the nodes cache with the new nodes mapping new_node_host, new_node_port = notification.node_address.split(":") self.cluster_client.nodes_manager.initialize( - disconect_startup_nodes_pools=False, + disconnect_startup_nodes_pools=False, additional_startup_nodes_info=[(new_node_host, int(new_node_port))], ) diff --git a/tests/maint_notifications/proxy_server_helpers.py b/tests/maint_notifications/proxy_server_helpers.py index 40dfd270b9..7358f078d8 100644 --- a/tests/maint_notifications/proxy_server_helpers.py +++ b/tests/maint_notifications/proxy_server_helpers.py @@ -2,7 +2,7 @@ from dataclasses import dataclass import logging import re -from typing import Optional, Union +from typing import Union from 
redis.http.http_client import HttpClient, HttpError diff --git a/tests/maint_notifications/test_maint_notifications.py b/tests/maint_notifications/test_maint_notifications.py index e4aec1f3a8..adb9ebb5ea 100644 --- a/tests/maint_notifications/test_maint_notifications.py +++ b/tests/maint_notifications/test_maint_notifications.py @@ -2,7 +2,6 @@ from unittest.mock import Mock, call, patch, MagicMock import pytest -from redis.cluster import ClusterNode from redis.connection import ConnectionInterface, MaintNotificationsAbstractConnection from redis.maint_notifications import ( From b0a2f9bcf01a0598adb324094002e497f3fb602c Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 27 Nov 2025 16:18:05 +0200 Subject: [PATCH 3/3] Applying review comments --- redis/client.py | 3 ++- redis/cluster.py | 26 ++++++++++++++++---------- redis/connection.py | 6 ++++-- redis/event.py | 3 ++- redis/utils.py | 8 ++++++++ 5 files changed, 32 insertions(+), 14 deletions(-) diff --git a/redis/client.py b/redis/client.py index 354ae3a68a..bc5e57b922 100755 --- a/redis/client.py +++ b/redis/client.py @@ -63,6 +63,7 @@ from redis.retry import Retry from redis.utils import ( _set_info_logger, + check_protocol_version, deprecated_args, get_lib_version, safe_str, @@ -366,7 +367,7 @@ def __init__( "ssl_ciphers": ssl_ciphers, } ) - if (cache_config or cache) and protocol in [3, "3"]: + if (cache_config or cache) and check_protocol_version(protocol, 3): kwargs.update( { "cache": cache, diff --git a/redis/cluster.py b/redis/cluster.py index ed26188c64..6892d71780 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -58,6 +58,7 @@ ) from redis.retry import Retry from redis.utils import ( + check_protocol_version, deprecated_args, dict_merge, list_keys_to_dict, @@ -233,19 +234,24 @@ def __init__( **kwargs, ): # Initialize maintenance notifications - is_protocol_supported = kwargs.get("protocol") in [3, "3"] + is_protocol_supported = check_protocol_version(kwargs.get("protocol"), 3) + + if ( + maint_notifications_config + and maint_notifications_config.enabled + and not is_protocol_supported + ): + raise RedisError( + "Maintenance notifications handlers on connection are only supported with RESP version 3" + ) if maint_notifications_config is None and is_protocol_supported: maint_notifications_config = MaintNotificationsConfig() self.maint_notifications_config = maint_notifications_config - if maint_notifications_config and maint_notifications_config.enabled: - if not is_protocol_supported: - raise RedisError( - "Maintenance notifications handlers on connection are only supported with RESP version 3" - ) + if self.maint_notifications_config and self.maint_notifications_config.enabled: self._oss_cluster_maint_notifications_handler = ( - OSSMaintNotificationsHandler(self, maint_notifications_config) + OSSMaintNotificationsHandler(self, self.maint_notifications_config) ) # Update connection kwargs for all future nodes connections self._update_connection_kwargs_for_maint_notifications( @@ -755,14 +761,14 @@ def __init__( kwargs.get("decode_responses", False), ) protocol = kwargs.get("protocol", None) - if (cache_config or cache) and protocol not in [3, "3"]: + if (cache_config or cache) and not check_protocol_version(protocol, 3): raise RedisError("Client caching is only supported with RESP version 3") - if maint_notifications_config and protocol not in [3, "3"]: + if maint_notifications_config and not check_protocol_version(protocol, 3): raise RedisError( "Maintenance notifications are only supported with RESP version 3" ) 
- if protocol in [3, "3"] and maint_notifications_config is None: + if check_protocol_version(protocol, 3) and maint_notifications_config is None: maint_notifications_config = MaintNotificationsConfig() self.command_flags = self.__class__.COMMAND_FLAGS.copy() diff --git a/redis/connection.py b/redis/connection.py index 1c16098098..c9a3221b0b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -59,6 +59,7 @@ CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, SSL_AVAILABLE, + check_protocol_version, compare_versions, deprecated_args, ensure_string, @@ -2127,7 +2128,8 @@ def __init__( **kwargs, ): # Initialize maintenance notifications - is_protocol_supported = kwargs.get("protocol") in [3, "3"] + is_protocol_supported = check_protocol_version(kwargs.get("protocol"), 3) + if maint_notifications_config is None and is_protocol_supported: maint_notifications_config = MaintNotificationsConfig() @@ -2615,7 +2617,7 @@ def __init__( self._cache_factory = cache_factory if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): - if self._connection_kwargs.get("protocol") not in [3, "3"]: + if not check_protocol_version(self._connection_kwargs.get("protocol"), 3): raise RedisError("Client caching is only supported with RESP version 3") cache = self._connection_kwargs.get("cache") diff --git a/redis/event.py b/redis/event.py index 03c72c6370..18a4b80af9 100644 --- a/redis/event.py +++ b/redis/event.py @@ -6,6 +6,7 @@ from redis.auth.token import TokenInterface from redis.credentials import CredentialProvider, StreamingCredentialProvider +from redis.utils import check_protocol_version class EventListenerInterface(ABC): @@ -427,7 +428,7 @@ def __init__(self): def listen(self, event: AfterPubSubConnectionInstantiationEvent): if isinstance( event.pubsub_connection.credential_provider, StreamingCredentialProvider - ) and event.pubsub_connection.get_protocol() in [3, "3"]: + ) and check_protocol_version(event.pubsub_connection.get_protocol(), 3): self._event = event self._connection = event.pubsub_connection self._connection_pool = event.connection_pool diff --git a/redis/utils.py b/redis/utils.py index 37a11a74f5..b134e1c103 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -213,6 +213,14 @@ def _set_info_logger(): logger.addHandler(handler) +def check_protocol_version( + protocol: Optional[Union[str, int]], expected_version: int = 3 +) -> bool: + if protocol is None: + return False + return int(protocol) == expected_version + + def get_lib_version(): try: libver = metadata.version("redis")
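
For reference, a minimal usage sketch of the check_protocol_version helper added in the hunk above. It is not part of the patch; the loop values below are purely illustrative. The helper replaces the scattered `protocol in [3, "3"]` membership checks by normalizing int and str inputs before comparing, and treating an unset protocol as "not RESP3":

    from redis.utils import check_protocol_version

    # 3 and "3" are treated the same; None and other versions are rejected.
    for value in (3, "3", 2, None):
        print(value, check_protocol_version(value, 3))
    # results: True, True, False, False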