diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 2da7a94b3d..c5d4678022 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -14,6 +14,7 @@ OSSNodeMigratedNotification, OSSNodeMigratingNotification, ) +from redis.utils import safe_str if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout @@ -194,8 +195,9 @@ def parse_oss_maintenance_completed_msg(response): # Expected message format is: # SMIGRATED id = response[1] - node_address = response[2] + node_address = safe_str(response[2]) slots = response[3] + return OSSNodeMigratedNotification(id, node_address, slots) @staticmethod @@ -225,9 +227,7 @@ def parse_moving_msg(response): if response[3] is None: host, port = None, None else: - value = response[3] - if isinstance(value, bytes): - value = value.decode() + value = safe_str(response[3]) host, port = value.split(":") port = int(port) if port is not None else None diff --git a/redis/client.py b/redis/client.py index e2712fc3f8..bc5e57b922 100755 --- a/redis/client.py +++ b/redis/client.py @@ -58,10 +58,12 @@ from redis.lock import Lock from redis.maint_notifications import ( MaintNotificationsConfig, + OSSMaintNotificationsHandler, ) from redis.retry import Retry from redis.utils import ( _set_info_logger, + check_protocol_version, deprecated_args, get_lib_version, safe_str, @@ -250,6 +252,9 @@ def __init__( cache_config: Optional[CacheConfig] = None, event_dispatcher: Optional[EventDispatcher] = None, maint_notifications_config: Optional[MaintNotificationsConfig] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, ) -> None: """ Initialize a new Redis client. @@ -288,6 +293,11 @@ def __init__( will be enabled by default (logic is included in the connection pool initialization). Argument is ignored when connection_pool is provided. + oss_cluster_maint_notifications_handler: + handler for OSS cluster notifications - see + `redis.maint_notifications.OSSMaintNotificationsHandler` for details. + Only supported with RESP3 + Argument is ignored when connection_pool is provided. """ if event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -357,7 +367,7 @@ def __init__( "ssl_ciphers": ssl_ciphers, } ) - if (cache_config or cache) and protocol in [3, "3"]: + if (cache_config or cache) and check_protocol_version(protocol, 3): kwargs.update( { "cache": cache, @@ -380,6 +390,12 @@ def __init__( "maint_notifications_config": maint_notifications_config, } ) + if oss_cluster_maint_notifications_handler: + kwargs.update( + { + "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler, + } + ) connection_pool = ConnectionPool(**kwargs) self._event_dispatcher.dispatch( AfterPooledConnectionsInstantiationEvent( diff --git a/redis/cluster.py b/redis/cluster.py index f06000563a..6892d71780 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -52,9 +52,13 @@ WatchError, ) from redis.lock import Lock -from redis.maint_notifications import MaintNotificationsConfig +from redis.maint_notifications import ( + MaintNotificationsConfig, + OSSMaintNotificationsHandler, +) from redis.retry import Retry from redis.utils import ( + check_protocol_version, deprecated_args, dict_merge, list_keys_to_dict, @@ -214,6 +218,67 @@ def cleanup_kwargs(**kwargs): return connection_kwargs +class MaintNotificationsAbstractRedisCluster: + """ + Abstract class for handling maintenance notifications logic. + This class is expected to be used as base class together with RedisCluster. + + This class is intended to be used with multiple inheritance! + + All logic related to maintenance notifications is encapsulated in this class. + """ + + def __init__( + self, + maint_notifications_config: Optional[MaintNotificationsConfig], + **kwargs, + ): + # Initialize maintenance notifications + is_protocol_supported = check_protocol_version(kwargs.get("protocol"), 3) + + if ( + maint_notifications_config + and maint_notifications_config.enabled + and not is_protocol_supported + ): + raise RedisError( + "Maintenance notifications handlers on connection are only supported with RESP version 3" + ) + if maint_notifications_config is None and is_protocol_supported: + maint_notifications_config = MaintNotificationsConfig() + + self.maint_notifications_config = maint_notifications_config + + if self.maint_notifications_config and self.maint_notifications_config.enabled: + self._oss_cluster_maint_notifications_handler = ( + OSSMaintNotificationsHandler(self, self.maint_notifications_config) + ) + # Update connection kwargs for all future nodes connections + self._update_connection_kwargs_for_maint_notifications( + self._oss_cluster_maint_notifications_handler + ) + # Update existing nodes connections - they are created as part of the RedisCluster constructor + for node in self.get_nodes(): + node.redis_connection.connection_pool.update_maint_notifications_config( + self.maint_notifications_config, + oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler, + ) + else: + self._oss_cluster_maint_notifications_handler = None + + def _update_connection_kwargs_for_maint_notifications( + self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler + ): + """ + Update the connection kwargs for all future connections. + """ + self.nodes_manager.connection_kwargs.update( + { + "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler, + } + ) + + class AbstractRedisCluster: RedisClusterRequestTTL = 16 @@ -461,7 +526,9 @@ def replace_default_node(self, target_node: "ClusterNode" = None) -> None: self.nodes_manager.default_node = random.choice(replicas) -class RedisCluster(AbstractRedisCluster, RedisClusterCommands): +class RedisCluster( + AbstractRedisCluster, MaintNotificationsAbstractRedisCluster, RedisClusterCommands +): @classmethod def from_url(cls, url, **kwargs): """ @@ -612,8 +679,7 @@ def __init__( `redis.maint_notifications.MaintNotificationsConfig` for details. Only supported with RESP3. If not provided and protocol is RESP3, the maintenance notifications - will be enabled by default (logic is included in the NodesManager - initialization). + will be enabled by default. :**kwargs: Extra arguments that will be sent into Redis instance when created (See Official redis-py doc for supported kwargs - the only limitation @@ -695,9 +761,16 @@ def __init__( kwargs.get("decode_responses", False), ) protocol = kwargs.get("protocol", None) - if (cache_config or cache) and protocol not in [3, "3"]: + if (cache_config or cache) and not check_protocol_version(protocol, 3): raise RedisError("Client caching is only supported with RESP version 3") + if maint_notifications_config and not check_protocol_version(protocol, 3): + raise RedisError( + "Maintenance notifications are only supported with RESP version 3" + ) + if check_protocol_version(protocol, 3) and maint_notifications_config is None: + maint_notifications_config = MaintNotificationsConfig() + self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.node_flags = self.__class__.NODE_FLAGS.copy() self.read_from_replicas = read_from_replicas @@ -709,6 +782,7 @@ def __init__( else: self._event_dispatcher = event_dispatcher self.startup_nodes = startup_nodes + self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, @@ -763,6 +837,10 @@ def __init__( self._aggregate_nodes = None self._lock = threading.RLock() + MaintNotificationsAbstractRedisCluster.__init__( + self, maint_notifications_config, **kwargs + ) + def __enter__(self): return self @@ -1632,9 +1710,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, cache_factory: Optional[CacheFactoryInterface] = None, event_dispatcher: Optional[EventDispatcher] = None, - maint_notifications_config: Optional[ - MaintNotificationsConfig - ] = MaintNotificationsConfig(), + maint_notifications_config: Optional[MaintNotificationsConfig] = None, **kwargs, ): self.nodes_cache: Dict[str, Redis] = {} @@ -1879,11 +1955,29 @@ def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): return target_node - def initialize(self): + def initialize( + self, + additional_startup_nodes_info: List[Tuple[str, int]] = [], + disconnect_startup_nodes_pools: bool = True, + ): """ Initializes the nodes cache, slots cache and redis connections. :startup_nodes: Responsible for discovering other nodes in the cluster + :disconnect_startup_nodes_pools: + Whether to disconnect the connection pool of the startup nodes + after the initialization is complete. This is useful when the + startup nodes are not part of the cluster and we want to avoid + keeping the connection open. + :additional_startup_nodes_info: + Additional nodes to add temporarily to the startup nodes. + The additional nodes will be used just in the process of extraction of the slots + and nodes information from the cluster. + This is useful when we want to add new nodes to the cluster + and initialize the client + with them. + The format of the list is a list of tuples, where each tuple contains + the host and port of the node. """ self.reset() tmp_nodes_cache = {} @@ -1893,9 +1987,25 @@ def initialize(self): fully_covered = False kwargs = self.connection_kwargs exception = None + + # Create cache if it's not provided and cache config is set + # should be done before initializing the first connection + # so that it will be applied to all connections + if self._cache is None and self._cache_config is not None: + if self._cache_factory is None: + self._cache = CacheFactory(self._cache_config).get_cache() + else: + self._cache = self._cache_factory.get_cache() + + additional_startup_nodes = [ + ClusterNode(host, port) for host, port in additional_startup_nodes_info + ] # Convert to tuple to prevent RuntimeError if self.startup_nodes # is modified during iteration - for startup_node in tuple(self.startup_nodes.values()): + for startup_node in ( + *self.startup_nodes.values(), + *additional_startup_nodes, + ): try: if startup_node.redis_connection: r = startup_node.redis_connection @@ -1911,7 +2021,11 @@ def initialize(self): # Make sure cluster mode is enabled on this node try: cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) - r.connection_pool.disconnect() + if disconnect_startup_nodes_pools: + # Disconnect the connection pool to avoid keeping the connection open + # For some cases we might not want to disconnect current pool and + # lose in flight commands responses + r.connection_pool.disconnect() except ResponseError: raise RedisClusterException( "Cluster mode is not enabled on this node" @@ -1992,12 +2106,6 @@ def initialize(self): f"one reachable node: {str(exception)}" ) from exception - if self._cache is None and self._cache_config is not None: - if self._cache_factory is None: - self._cache = CacheFactory(self._cache_config).get_cache() - else: - self._cache = self._cache_factory.get_cache() - # Create Redis connections to all nodes self.create_redis_connections(list(tmp_nodes_cache.values())) diff --git a/redis/connection.py b/redis/connection.py index 0a87777ac3..c9a3221b0b 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -52,12 +52,14 @@ MaintNotificationsConfig, MaintNotificationsConnectionHandler, MaintNotificationsPoolHandler, + OSSMaintNotificationsHandler, ) from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, SSL_AVAILABLE, + check_protocol_version, compare_versions, deprecated_args, ensure_string, @@ -285,6 +287,9 @@ def __init__( orig_host_address: Optional[str] = None, orig_socket_timeout: Optional[float] = None, orig_socket_connect_timeout: Optional[float] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None, ): """ @@ -298,6 +303,7 @@ def __init__( orig_host_address (Optional[str]): The original host address of the connection. orig_socket_timeout (Optional[float]): The original socket timeout of the connection. orig_socket_connect_timeout (Optional[float]): The original socket connect timeout of the connection. + oss_cluster_maint_notifications_handler (Optional[OSSMaintNotificationsHandler]): The OSS cluster handler for maintenance notifications. parser (Optional[Union[_HiredisParser, _RESP3Parser]]): The parser to use for maintenance notifications. If not provided, the parser from the connection is used. This is useful when the parser is created after this object. @@ -310,6 +316,7 @@ def __init__( orig_host_address, orig_socket_timeout, orig_socket_connect_timeout, + oss_cluster_maint_notifications_handler, parser, ) @@ -386,6 +393,9 @@ def _configure_maintenance_notifications( orig_host_address=None, orig_socket_timeout=None, orig_socket_connect_timeout=None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, parser: Optional[Union[_HiredisParser, _RESP3Parser]] = None, ): """ @@ -400,6 +410,7 @@ def _configure_maintenance_notifications( ): self._maint_notifications_pool_handler = None self._maint_notifications_connection_handler = None + self._oss_cluster_maint_notifications_handler = None return if not parser: @@ -427,11 +438,30 @@ def _configure_maintenance_notifications( else: self._maint_notifications_pool_handler = None + if oss_cluster_maint_notifications_handler: + # Extract a reference to a new handler that copies all properties + # of the original one and has a different connection reference + # This is needed because when we attach the handler to the parser + # we need to make sure that the handler has a reference to the + # connection that the parser is attached to. + self._oss_cluster_maint_notifications_handler = ( + oss_cluster_maint_notifications_handler.get_handler_for_connection() + ) + self._oss_cluster_maint_notifications_handler.set_connection(self) + else: + self._oss_cluster_maint_notifications_handler = None + self._maint_notifications_connection_handler = ( MaintNotificationsConnectionHandler(self, self.maint_notifications_config) ) - # Set up pool handler if available + # Set up OSS cluster handler to parser if available + if self._oss_cluster_maint_notifications_handler: + parser.set_oss_cluster_maint_push_handler( + self._oss_cluster_maint_notifications_handler.handle_notification + ) + + # Set up pool handler to parser if available if self._maint_notifications_pool_handler: parser.set_node_moving_push_handler( self._maint_notifications_pool_handler.handle_notification @@ -486,6 +516,41 @@ def set_maint_notifications_pool_handler_for_connection( maint_notifications_pool_handler.config ) + def set_maint_notifications_cluster_handler_for_connection( + self, oss_cluster_maint_notifications_handler: OSSMaintNotificationsHandler + ): + # Deep copy the cluster handler to avoid sharing the same handler + # between multiple connections, because otherwise each connection will override + # the connection reference and the handler will only hold a reference + # to the last connection that was set. + maint_notifications_cluster_handler_copy = ( + oss_cluster_maint_notifications_handler.get_handler_for_connection() + ) + + maint_notifications_cluster_handler_copy.set_connection(self) + self._get_parser().set_oss_cluster_maint_push_handler( + maint_notifications_cluster_handler_copy.handle_notification + ) + + self._oss_cluster_maint_notifications_handler = ( + maint_notifications_cluster_handler_copy + ) + + # Update maintenance notification connection handler if it doesn't exist + if not self._maint_notifications_connection_handler: + self._maint_notifications_connection_handler = ( + MaintNotificationsConnectionHandler( + self, oss_cluster_maint_notifications_handler.config + ) + ) + self._get_parser().set_maintenance_push_handler( + self._maint_notifications_connection_handler.handle_notification + ) + else: + self._maint_notifications_connection_handler.config = ( + oss_cluster_maint_notifications_handler.config + ) + def activate_maint_notifications_handling_if_enabled(self, check_health=True): # Send maintenance notifications handshake if RESP3 is active # and maintenance notifications are enabled @@ -688,6 +753,9 @@ def __init__( orig_host_address: Optional[str] = None, orig_socket_timeout: Optional[float] = None, orig_socket_connect_timeout: Optional[float] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, ): """ Initialize a new Connection. @@ -782,6 +850,7 @@ def __init__( orig_host_address, orig_socket_timeout, orig_socket_connect_timeout, + oss_cluster_maint_notifications_handler, self._parser, ) @@ -1352,6 +1421,7 @@ def __init__( self._conn.host, self._conn.socket_timeout, self._conn.socket_connect_timeout, + self._conn._oss_cluster_maint_notifications_handler, self._conn._get_parser(), ) @@ -1375,6 +1445,14 @@ def set_maint_notifications_pool_handler_for_connection( maint_notifications_pool_handler ) + def set_maint_notifications_cluster_handler_for_connection( + self, oss_cluster_maint_notifications_handler + ): + if isinstance(self._conn, MaintNotificationsAbstractConnection): + self._conn.set_maint_notifications_cluster_handler_for_connection( + oss_cluster_maint_notifications_handler + ) + def get_protocol(self): return self._conn.get_protocol() @@ -1564,6 +1642,19 @@ def socket_connect_timeout(self) -> Optional[Union[float, int]]: def socket_connect_timeout(self, value: Optional[Union[float, int]]): self._conn.socket_connect_timeout = value + @property + def _maint_notifications_connection_handler( + self, + ) -> Optional[MaintNotificationsConnectionHandler]: + if isinstance(self._conn, MaintNotificationsAbstractConnection): + return self._conn._maint_notifications_connection_handler + + @_maint_notifications_connection_handler.setter + def _maint_notifications_connection_handler( + self, value: Optional[MaintNotificationsConnectionHandler] + ): + self._conn._maint_notifications_connection_handler = value + def _get_socket(self) -> Optional[socket.socket]: if isinstance(self._conn, MaintNotificationsAbstractConnection): return self._conn._get_socket() @@ -2031,10 +2122,14 @@ class MaintNotificationsAbstractConnectionPool: def __init__( self, maint_notifications_config: Optional[MaintNotificationsConfig] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, **kwargs, ): # Initialize maintenance notifications - is_protocol_supported = kwargs.get("protocol") in [3, "3"] + is_protocol_supported = check_protocol_version(kwargs.get("protocol"), 3) + if maint_notifications_config is None and is_protocol_supported: maint_notifications_config = MaintNotificationsConfig() @@ -2043,16 +2138,26 @@ def __init__( raise RedisError( "Maintenance notifications handlers on connection are only supported with RESP version 3" ) + if oss_cluster_maint_notifications_handler: + self._oss_cluster_maint_notifications_handler = ( + oss_cluster_maint_notifications_handler + ) + self._update_connection_kwargs_for_maint_notifications( + oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler + ) + self._maint_notifications_pool_handler = None + else: + self._oss_cluster_maint_notifications_handler = None + self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( + self, maint_notifications_config + ) - self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( - self, maint_notifications_config - ) - - self._update_connection_kwargs_for_maint_notifications( - self._maint_notifications_pool_handler - ) + self._update_connection_kwargs_for_maint_notifications( + maint_notifications_pool_handler=self._maint_notifications_pool_handler + ) else: self._maint_notifications_pool_handler = None + self._oss_cluster_maint_notifications_handler = None @property @abstractmethod @@ -2085,16 +2190,25 @@ def maint_notifications_enabled(self): The maintenance notifications config is stored in the pool handler. If the pool handler is not set, the maintenance notifications are not enabled. """ - maint_notifications_config = ( - self._maint_notifications_pool_handler.config - if self._maint_notifications_pool_handler - else None - ) + if self._oss_cluster_maint_notifications_handler: + maint_notifications_config = ( + self._oss_cluster_maint_notifications_handler.config + ) + else: + maint_notifications_config = ( + self._maint_notifications_pool_handler.config + if self._maint_notifications_pool_handler + else None + ) return maint_notifications_config and maint_notifications_config.enabled def update_maint_notifications_config( - self, maint_notifications_config: MaintNotificationsConfig + self, + maint_notifications_config: MaintNotificationsConfig, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, ): """ Updates the maintenance notifications configuration. @@ -2110,37 +2224,59 @@ def update_maint_notifications_config( raise ValueError( "Cannot disable maintenance notifications after enabling them" ) - # first update pool settings - if not self._maint_notifications_pool_handler: - self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( - self, maint_notifications_config + if oss_cluster_maint_notifications_handler: + self._oss_cluster_maint_notifications_handler = ( + oss_cluster_maint_notifications_handler ) else: - self._maint_notifications_pool_handler.config = maint_notifications_config + # first update pool settings + if not self._maint_notifications_pool_handler: + self._maint_notifications_pool_handler = MaintNotificationsPoolHandler( + self, maint_notifications_config + ) + else: + self._maint_notifications_pool_handler.config = ( + maint_notifications_config + ) # then update connection kwargs and existing connections self._update_connection_kwargs_for_maint_notifications( - self._maint_notifications_pool_handler + maint_notifications_pool_handler=self._maint_notifications_pool_handler, + oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler, ) self._update_maint_notifications_configs_for_connections( - self._maint_notifications_pool_handler + maint_notifications_pool_handler=self._maint_notifications_pool_handler, + oss_cluster_maint_notifications_handler=self._oss_cluster_maint_notifications_handler, ) def _update_connection_kwargs_for_maint_notifications( - self, maint_notifications_pool_handler: MaintNotificationsPoolHandler + self, + maint_notifications_pool_handler: Optional[ + MaintNotificationsPoolHandler + ] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, ): """ Update the connection kwargs for all future connections. """ if not self.maint_notifications_enabled(): return - - self.connection_kwargs.update( - { - "maint_notifications_pool_handler": maint_notifications_pool_handler, - "maint_notifications_config": maint_notifications_pool_handler.config, - } - ) + if maint_notifications_pool_handler: + self.connection_kwargs.update( + { + "maint_notifications_pool_handler": maint_notifications_pool_handler, + "maint_notifications_config": maint_notifications_pool_handler.config, + } + ) + if oss_cluster_maint_notifications_handler: + self.connection_kwargs.update( + { + "oss_cluster_maint_notifications_handler": oss_cluster_maint_notifications_handler, + "maint_notifications_config": oss_cluster_maint_notifications_handler.config, + } + ) # Store original connection parameters for maintenance notifications. if self.connection_kwargs.get("orig_host_address", None) is None: @@ -2159,25 +2295,56 @@ def _update_connection_kwargs_for_maint_notifications( ) def _update_maint_notifications_configs_for_connections( - self, maint_notifications_pool_handler: MaintNotificationsPoolHandler + self, + maint_notifications_pool_handler: Optional[ + MaintNotificationsPoolHandler + ] = None, + oss_cluster_maint_notifications_handler: Optional[ + OSSMaintNotificationsHandler + ] = None, ): """Update the maintenance notifications config for all connections in the pool.""" with self._get_pool_lock(): for conn in self._get_free_connections(): - conn.set_maint_notifications_pool_handler_for_connection( - maint_notifications_pool_handler - ) - conn.maint_notifications_config = ( - maint_notifications_pool_handler.config - ) + if oss_cluster_maint_notifications_handler: + # set cluster handler for conn + conn.set_maint_notifications_cluster_handler_for_connection( + oss_cluster_maint_notifications_handler + ) + conn.maint_notifications_config = ( + oss_cluster_maint_notifications_handler.config + ) + elif maint_notifications_pool_handler: + conn.set_maint_notifications_pool_handler_for_connection( + maint_notifications_pool_handler + ) + conn.maint_notifications_config = ( + maint_notifications_pool_handler.config + ) + else: + raise ValueError( + "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set" + ) conn.disconnect() for conn in self._get_in_use_connections(): - conn.set_maint_notifications_pool_handler_for_connection( - maint_notifications_pool_handler - ) - conn.maint_notifications_config = ( - maint_notifications_pool_handler.config - ) + if oss_cluster_maint_notifications_handler: + conn.maint_notifications_config = ( + oss_cluster_maint_notifications_handler.config + ) + conn._configure_maintenance_notifications( + oss_cluster_maint_notifications_handler=oss_cluster_maint_notifications_handler + ) + elif maint_notifications_pool_handler: + conn.set_maint_notifications_pool_handler_for_connection( + maint_notifications_pool_handler + ) + conn.maint_notifications_config = ( + maint_notifications_pool_handler.config + ) + else: + raise ValueError( + "Either maint_notifications_pool_handler or oss_cluster_maint_notifications_handler must be set" + ) conn.mark_for_reconnect() def _should_update_connection( @@ -2450,7 +2617,7 @@ def __init__( self._cache_factory = cache_factory if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): - if self._connection_kwargs.get("protocol") not in [3, "3"]: + if not check_protocol_version(self._connection_kwargs.get("protocol"), 3): raise RedisError("Client caching is only supported with RESP version 3") cache = self._connection_kwargs.get("cache") diff --git a/redis/event.py b/redis/event.py index 03c72c6370..18a4b80af9 100644 --- a/redis/event.py +++ b/redis/event.py @@ -6,6 +6,7 @@ from redis.auth.token import TokenInterface from redis.credentials import CredentialProvider, StreamingCredentialProvider +from redis.utils import check_protocol_version class EventListenerInterface(ABC): @@ -427,7 +428,7 @@ def __init__(self): def listen(self, event: AfterPubSubConnectionInstantiationEvent): if isinstance( event.pubsub_connection.credential_provider, StreamingCredentialProvider - ) and event.pubsub_connection.get_protocol() in [3, "3"]: + ) and check_protocol_version(event.pubsub_connection.get_protocol(), 3): self._event = event self._connection = event.pubsub_connection self._connection_pool = event.connection_pool diff --git a/redis/maint_notifications.py b/redis/maint_notifications.py index dbb97d37b9..cea0da47bd 100644 --- a/redis/maint_notifications.py +++ b/redis/maint_notifications.py @@ -10,7 +10,7 @@ from redis.typing import Number if TYPE_CHECKING: - from redis.cluster import NodesManager + from redis.cluster import MaintNotificationsAbstractRedisCluster class MaintenanceState(enum.Enum): @@ -463,8 +463,8 @@ class OSSNodeMigratedNotification(MaintenanceNotification): Args: id (int): Unique identifier for this notification - node_address (Optional[str]): Address of the node that has - completed migration - this is the destination node. + node_address (Optional[str]): Address of the node that has completed migration + in the format "host:port" slots (Optional[List[int]]): List of slots that have been migrated """ @@ -473,7 +473,7 @@ class OSSNodeMigratedNotification(MaintenanceNotification): def __init__( self, id: int, - node_address: Optional[str] = None, + node_address: str, slots: Optional[List[int]] = None, ): super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL) @@ -935,12 +935,30 @@ def handle_maintenance_completed_notification(self): class OSSMaintNotificationsHandler: def __init__( - self, nodes_manager: "NodesManager", config: MaintNotificationsConfig + self, + cluster_client: "MaintNotificationsAbstractRedisCluster", + config: MaintNotificationsConfig, ) -> None: - self.nodes_manager = nodes_manager + self.cluster_client = cluster_client self.config = config self._processed_notifications = set() + self._in_progress = set() self._lock = threading.RLock() + self.connection = None + + def set_connection(self, connection: "MaintNotificationsAbstractConnection"): + self.connection = connection + + def get_handler_for_connection(self): + # Copy all data that should be shared between connections + # but each connection should have its own pool handler + # since each connection can be in a different state + copy = OSSMaintNotificationsHandler(self.cluster_client, self.config) + copy._processed_notifications = self._processed_notifications + copy._in_progress = self._in_progress + copy._lock = self._lock + copy.connection = None + return copy def remove_expired_notifications(self): with self._lock: @@ -958,4 +976,58 @@ def handle_oss_maintenance_completed_notification( self, notification: OSSNodeMigratedNotification ): self.remove_expired_notifications() - logging.info(f"Received OSS maintenance completed notification: {notification}") + + with self._lock: + if ( + notification in self._in_progress + or notification in self._processed_notifications + ): + # we are already handling this notification or it has already been processed + # we should skip in_progress notification since when we reinitialize the cluster + # we execute a CLUSTER SLOTS command that can use a different connection + # that has also has the notification and we don't want to + # process the same notification twice + return + + self._in_progress.add(notification) + + # get the node to which the connection is connected + # before refreshing the cluster topology + current_node = self.cluster_client.nodes_manager.get_node( + host=self.connection.host, port=self.connection.port + ) + + # Updates the cluster slots cache with the new slots mapping + # This will also update the nodes cache with the new nodes mapping + new_node_host, new_node_port = notification.node_address.split(":") + self.cluster_client.nodes_manager.initialize( + disconnect_startup_nodes_pools=False, + additional_startup_nodes_info=[(new_node_host, int(new_node_port))], + ) + + if ( + current_node + not in self.cluster_client.nodes_manager.nodes_cache.values() + ): + # disconnect all free connections to the node + for conn in current_node.redis_connection.connection_pool._get_free_connections(): + conn.disconnect() + # mark for reconnect all in use connections to the node - this will force them to + # disconnect after they complete their current commands + for conn in current_node.redis_connection.connection_pool._get_in_use_connections(): + conn.mark_for_reconnect() + else: + if self.config.is_relaxed_timeouts_enabled(): + # reset the timeouts for the node to which the connection is connected + # TODO: add check if other maintenance ops are in progress for the same node - CAE-1038 + # and if so, don't reset the timeouts + for conn in ( + *current_node.redis_connection.connection_pool._get_in_use_connections(), + *current_node.redis_connection.connection_pool._get_free_connections(), + ): + conn.update_current_socket_timeout(relaxed_timeout=-1) + conn.maintenance_state = MaintenanceState.NONE + + # mark the notification as processed + self._processed_notifications.add(notification) + self._in_progress.remove(notification) diff --git a/redis/utils.py b/redis/utils.py index 37a11a74f5..b134e1c103 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -213,6 +213,14 @@ def _set_info_logger(): logger.addHandler(handler) +def check_protocol_version( + protocol: Optional[Union[str, int]], expected_version: int = 3 +) -> bool: + if protocol is None: + return False + return int(protocol) == expected_version + + def get_lib_version(): try: libver = metadata.version("redis") diff --git a/tests/maint_notifications/proxy_server_helpers.py b/tests/maint_notifications/proxy_server_helpers.py index b3397beefa..7358f078d8 100644 --- a/tests/maint_notifications/proxy_server_helpers.py +++ b/tests/maint_notifications/proxy_server_helpers.py @@ -1,11 +1,10 @@ import base64 +from dataclasses import dataclass import logging import re from typing import Union from redis.http.http_client import HttpClient, HttpError -# from urllib.request import Request, urlopen -# from urllib.error import URLError class RespTranslator: @@ -34,7 +33,7 @@ def cluster_slots_to_resp(resp: str) -> str: ) @staticmethod - def smigrating_to_resp(resp: str) -> str: + def oss_maint_notification_to_resp(resp: str) -> str: """Convert query to RESP format.""" return ( f">{len(resp.split())}\r\n" @@ -45,6 +44,14 @@ def smigrating_to_resp(resp: str) -> str: ) +@dataclass +class SlotsRange: + host: str + port: int + start_slot: int + end_slot: int + + class ProxyInterceptorHelper: """Helper class for intercepting socket calls and managing interceptor server.""" @@ -52,6 +59,7 @@ def __init__(self, server_url: str = "http://localhost:4000"): self.server_url = server_url self._resp_translator = RespTranslator() self.http_client = HttpClient() + self._interceptors = list() def cleanup_interceptors(self, *names: str): """ @@ -60,56 +68,67 @@ def cleanup_interceptors(self, *names: str): Args: names: Names of the interceptors to reset """ - for name in names: + if not names: + names = self._interceptors + for name in tuple(names): self._reset_interceptor(name) - def set_cluster_nodes(self, name: str, nodes: list[tuple[str, int]]) -> str: + def set_cluster_slots( + self, + name: str, + slots_ranges: list[SlotsRange], + ) -> str: """ - Set cluster nodes by intercepting CLUSTER SLOTS command. + Set cluster slots and nodes by intercepting CLUSTER SLOTS command. This method creates an interceptor that intercepts CLUSTER SLOTS commands - and returns a modified topology with the provided nodes. + and returns a modified topology with the provided data. Args: name: Name of the interceptor - nodes: List of (host, port) tuples representing the cluster nodes + slots_ranges: List of SlotsRange objects representing the cluster + nodes and slots coverage Returns: The interceptor name that was created Example: interceptor = ProxyInterceptorHelper(None, "http://localhost:4000") - interceptor_name = interceptor.set_cluster_nodes( + interceptor.set_cluster_slots( "test_topology", - [("127.0.0.1", 6379), ("127.0.0.1", 6380), ("127.0.0.1", 6381)] + [ + SlotsRange("127.0.0.1", 6379, 0, 5000), + SlotsRange("127.0.0.1", 6380, 5001, 10000), + SlotsRange("127.0.0.1", 6381, 10001, 16383), + ] ) """ # Build RESP response for CLUSTER SLOTS # Format: * for each range: *3 :start :end *3 $ : $ - resp_parts = [f"*{len(nodes)}"] - - # For simplicity, distribute slots evenly across nodes - total_slots = 16384 - slots_per_node = total_slots // len(nodes) - - for i, (host, port) in enumerate(nodes): - start_slot = i * slots_per_node - end_slot = ( - (i + 1) * slots_per_node - 1 if i < len(nodes) - 1 else total_slots - 1 - ) + resp_parts = [f"*{len(slots_ranges)}"] + for slots_range in slots_ranges: # Node info: *3 for (host, port, id) resp_parts.append("*3") - resp_parts.append(f":{start_slot}") - resp_parts.append(f":{end_slot}") - - # Node details: *3 for (host, port, id) - resp_parts.append("*3") - resp_parts.append(f"${len(host)}") - resp_parts.append(host) - resp_parts.append(f":{port}") - resp_parts.append("$13") - resp_parts.append(f"proxy-id-{port}") + # 1st elem --> start slot + resp_parts.append(f":{slots_range.start_slot}") + # 2nd elem --> end slot + resp_parts.append(f":{slots_range.end_slot}") + + # 3rd elem --> list with node details: *4 for (host, port, id, empty hash) + resp_parts.append("*4") + # 1st elem --> host + resp_parts.append(f"${len(slots_range.host)}") + resp_parts.append(f"{slots_range.host}") + # 2nd elem --> port + resp_parts.append(f":{slots_range.port}") + # 3rd elem --> node id + node_id = f"proxy-id-{slots_range.port}" + resp_parts.append(f"${len(node_id)}") + resp_parts.append(node_id) + # 4th elem --> empty hash + resp_parts.append("$0") + resp_parts.append("") response = "\r\n".join(resp_parts) + "\r\n" @@ -257,7 +276,10 @@ def _add_interceptor( proxy_response = self.http_client.post( url, json_body=payload, headers=headers ) - return proxy_response.json() + self._interceptors.append(name) + if isinstance(proxy_response, dict): + return proxy_response + return proxy_response.json() if proxy_response else {} except HttpError as e: raise RuntimeError(f"Failed to add interceptor: {e}") @@ -268,4 +290,4 @@ def _reset_interceptor(self, name: str): Args: name: Name of the interceptor to reset """ - self._add_interceptor(name, "", "") + self._add_interceptor(name, "no_match", "") diff --git a/tests/maint_notifications/test_cluster_maint_notifications_handling.py b/tests/maint_notifications/test_cluster_maint_notifications_handling.py index 8e8d53b62f..97e91d2a05 100644 --- a/tests/maint_notifications/test_cluster_maint_notifications_handling.py +++ b/tests/maint_notifications/test_cluster_maint_notifications_handling.py @@ -11,12 +11,15 @@ from tests.maint_notifications.proxy_server_helpers import ( ProxyInterceptorHelper, RespTranslator, + SlotsRange, ) NODE_PORT_1 = 15379 NODE_PORT_2 = 15380 NODE_PORT_3 = 15381 +NODE_PORT_NEW = 15382 + # Initial cluster node configuration for proxy-based tests PROXY_CLUSTER_NODES = [ ClusterNode("127.0.0.1", NODE_PORT_1), @@ -34,6 +37,7 @@ def _create_cluster_client( enable_cache=False, max_connections=10, maint_config=None, + protocol=3, ) -> RedisCluster: """Create a RedisCluster instance with mocked sockets.""" if maint_config is None and hasattr(self, "config") and self.config is not None: @@ -44,7 +48,7 @@ def _create_cluster_client( kwargs = {"cache_config": CacheConfig()} test_redis_client = RedisCluster( - protocol=3, + protocol=protocol, startup_nodes=PROXY_CLUSTER_NODES, maint_notifications_config=maint_config, connection_pool_class=pool_class, @@ -214,14 +218,37 @@ def test_config_with_cache_enabled(self): def test_none_config_default_behavior(self): """ - Test that when maint_notifications_config=None, the system works without errors. + Test that when maint_notifications_config=None, it will be initialized with default values. """ cluster = self._create_cluster_client(maint_config=None) try: # Verify cluster is created successfully assert cluster.nodes_manager is not None + # for protocol 3, maint_notifications_config should be initialized with default values + assert cluster.nodes_manager.maint_notifications_config is not None + assert cluster.nodes_manager.maint_notifications_config.enabled == "auto" + assert len(cluster.nodes_manager.nodes_cache) > 0 + # Verify we can execute commands without errors + cluster.set("test", "VAL") + res = cluster.get("test") + assert res == b"VAL" + finally: + cluster.close() + + def test_none_config_default_behavior_for_protocol_2(self): + """ + Test that when maint_notifications_config=None and protocol=2, + it will not be initialized. + """ + cluster = self._create_cluster_client(protocol=2) + + try: + # Verify cluster is created successfully + assert cluster.nodes_manager is not None + # for protocol 2, maint_notifications_config should not be created assert cluster.nodes_manager.maint_notifications_config is None + assert len(cluster.nodes_manager.nodes_cache) > 0 # Verify we can execute commands without errors cluster.set("test", "VAL") @@ -285,6 +312,95 @@ def test_config_with_pipeline_operations(self): cluster.close() +class TestClusterMaintNotificationsHandler(TestClusterMaintNotificationsBase): + """Test OSSMaintNotificationsHandler propagation with RedisCluster.""" + + def _validate_connection_handlers( + self, conn, cluster_client, config, is_cache_conn=False + ): + """Helper method to validate connection handlers are properly set.""" + # Test that the oss cluster handler function is correctly set + oss_cluster_parser_handler_set_for_con = ( + conn._parser.oss_cluster_maint_push_handler_func + ) + assert oss_cluster_parser_handler_set_for_con is not None + assert hasattr(oss_cluster_parser_handler_set_for_con, "__self__") + assert hasattr(oss_cluster_parser_handler_set_for_con, "__func__") + assert oss_cluster_parser_handler_set_for_con.__self__.connection is conn + assert ( + oss_cluster_parser_handler_set_for_con.__self__.cluster_client + is cluster_client + ) + assert ( + oss_cluster_parser_handler_set_for_con.__self__._lock + is cluster_client._oss_cluster_maint_notifications_handler._lock + ) + assert ( + oss_cluster_parser_handler_set_for_con.__self__._processed_notifications + is cluster_client._oss_cluster_maint_notifications_handler._processed_notifications + ) + assert ( + oss_cluster_parser_handler_set_for_con.__func__ + is cluster_client._oss_cluster_maint_notifications_handler.handle_notification.__func__ + ) + + # Test that the maintenance handler function is correctly set + parser_maint_handler_set_for_con = conn._parser.maintenance_push_handler_func + assert parser_maint_handler_set_for_con is not None + assert hasattr(parser_maint_handler_set_for_con, "__self__") + assert hasattr(parser_maint_handler_set_for_con, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance notification connection handler + assert ( + parser_maint_handler_set_for_con.__self__ + is conn._maint_notifications_connection_handler + ) + assert ( + parser_maint_handler_set_for_con.__func__ + is conn._maint_notifications_connection_handler.handle_notification.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maint_notifications_connection_handler.config is config + + def test_oss_maint_handler_propagation(self): + """Test that OSSMaintNotificationsHandler is propagated to all connections.""" + cluster = self._create_cluster_client() + # Verify all nodes have the handler + for node in cluster.nodes_manager.nodes_cache.values(): + assert node.redis_connection is not None + assert node.redis_connection.connection_pool is not None + for conn in ( + *node.redis_connection.connection_pool._get_in_use_connections(), + *node.redis_connection.connection_pool._get_free_connections(), + ): + assert conn._oss_cluster_maint_notifications_handler is not None + assert conn._oss_cluster_maint_notifications_handler.connection is conn + self._validate_connection_handlers( + conn, cluster, cluster.maint_notifications_config + ) + + def test_oss_maint_handler_propagation_cache_enabled(self): + """Test that OSSMaintNotificationsHandler is propagated to all connections.""" + cluster = self._create_cluster_client(enable_cache=True) + # Verify all nodes have the handler + for node in cluster.nodes_manager.nodes_cache.values(): + assert node.redis_connection is not None + assert node.redis_connection.connection_pool is not None + for conn in ( + *node.redis_connection.connection_pool._get_in_use_connections(), + *node.redis_connection.connection_pool._get_free_connections(), + ): + assert conn._conn._oss_cluster_maint_notifications_handler is not None + assert ( + conn._conn._oss_cluster_maint_notifications_handler.connection + is conn._conn + ) + self._validate_connection_handlers( + conn._conn, cluster, cluster.maint_notifications_config + ) + + class TestClusterMaintNotificationsHandlingBase(TestClusterMaintNotificationsBase): """Base class for maintenance notifications handling tests.""" @@ -317,6 +433,21 @@ class ConnectionStateExpectation: class TestClusterMaintNotificationsHandling(TestClusterMaintNotificationsHandlingBase): """Test maintenance notifications handling with RedisCluster.""" + def _warm_up_connection_pools( + self, cluster: RedisCluster, created_connections_count: int = 3 + ): + """Warm up connection pools by getting a connection from each pool.""" + for node in cluster.nodes_manager.nodes_cache.values(): + node_connections = [] + for _ in range(created_connections_count): + node_connections.append( + node.redis_connection.connection_pool.get_connection() + ) + for conn in node_connections: + node.redis_connection.connection_pool.release(conn) + + node_connections.clear() + def _get_expected_node_state( self, expectations_list: List[ConnectionStateExpectation], node_port: int ) -> Optional[ConnectionStateExpectation]: @@ -361,23 +492,26 @@ def _validate_connections_states( changed_connections_count += 1 assert changed_connections_count == expected_state.changed_connections_count - def test_receive_oss_maintenance_notification(self): - """Test receiving an OSS maintenance notification.""" - # get three connections from each node - for node in self.cluster.nodes_manager.nodes_cache.values(): - node_connections = [] - for _ in range(3): - node_connections.append( - node.redis_connection.connection_pool.get_connection() - ) - for conn in node_connections: - node.redis_connection.connection_pool.release(conn) + def _validate_removed_node_connections(self, node): + """Validate connections in a removed node.""" + assert node.redis_connection is not None + connection_pool = node.redis_connection.connection_pool + assert connection_pool is not None - node_connections.clear() + # validate all connections are disconnected or marked for reconnect + for conn in connection_pool._get_free_connections(): + assert conn._sock is None + for conn in connection_pool._get_in_use_connections(): + assert conn.should_reconnect() + + def test_receive_smigrating_notification(self): + """Test receiving an OSS maintenance notification.""" + # warm up connection pools + self._warm_up_connection_pools(self.cluster, created_connections_count=3) # send a notification to node 1 - notification = RespTranslator.smigrating_to_resp( - "SMIGRATING 12 TO 127.0.0.1:15380 <123,456,5000-7000>" + notification = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <123,456,5000-7000>" ) self.proxy_helper.send_notification(NODE_PORT_1, notification) @@ -420,24 +554,424 @@ def test_receive_oss_maintenance_notification(self): ], ) - def test_receive_maint_notification(self): - """Test receiving a maintenance notification.""" - self.cluster.set("test", "VAL") - pubsub = self.cluster.pubsub() - pubsub.subscribe("test") - test_msg = pubsub.get_message(ignore_subscribe_messages=True, timeout=10) - print(test_msg) + def test_receive_smigrating_with_disabled_relaxed_timeout(self): + """Test receiving an OSS maintenance notification with disabled relaxed timeout.""" + # Create config with disabled relaxed timeout + disabled_config = MaintNotificationsConfig( + enabled="auto", + relaxed_timeout=-1, # This means the relaxed timeout is Disabled + ) + cluster = self._create_cluster_client(maint_config=disabled_config) + + # warm up connection pools + self._warm_up_connection_pools(cluster, created_connections_count=3) + + # send a notification to node 1 + notification = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <123,456,5000-7000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, notification) + + # validate no timeout is relaxed on any connection + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, changed_connections_count=0 + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, changed_connections_count=0 + ), + ConnectionStateExpectation( + node_port=NODE_PORT_3, changed_connections_count=0 + ), + ], + ) + + def test_receive_smigrated_notification(self): + """Test receiving an OSS maintenance completed notification.""" + # create three connections in each node's connection pool + self._warm_up_connection_pools(self.cluster, created_connections_count=3) + + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", NODE_PORT_NEW, 0, 5460), + SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + # send a notification to node 1 + notification = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 12 127.0.0.1:15380 <123,456,5000-7000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, notification) + + # execute a command that will receive the notification + res = self.cluster.set("anyprefix:{3}:k", "VAL") + assert res is True + + # validate the cluster topology was updated + new_node = self.cluster.nodes_manager.get_node( + host="0.0.0.0", port=NODE_PORT_NEW + ) + assert new_node is not None + + def test_smigrating_smigrated_on_two_nodes_without_node_replacement(self): + """Test receiving an OSS maintenance notification on two nodes without node replacement.""" + # warm up connection pools - create several connections in each pool + self._warm_up_connection_pools(self.cluster, created_connections_count=3) + + node_1 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_1) + node_2 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_2) + + smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <123,2000-3000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, changed_connections_count=0 + ), + ], + ) + + smigrating_node_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 13 <8000-9000>" + ) + self.proxy_helper.send_notification(NODE_PORT_2, smigrating_node_2) + + # execute command with node 2 connection + self.cluster.set("anyprefix:{1}:k", "VAL") + + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 14 0.0.0.0:15381 <123,2000-3000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + + smigrated_node_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 15 0.0.0.0:15381 <8000-9000>" + ) + self.proxy_helper.send_notification(NODE_PORT_2, smigrated_node_2) + + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", NODE_PORT_1, 0, 122), + SlotsRange("0.0.0.0", NODE_PORT_3, 123, 123), + SlotsRange("0.0.0.0", NODE_PORT_1, 124, 2000), + SlotsRange("0.0.0.0", NODE_PORT_3, 2001, 3000), + SlotsRange("0.0.0.0", NODE_PORT_1, 3001, 5460), + SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + + # validate the cluster topology was updated + # validate old nodes are there + assert node_1 in self.cluster.nodes_manager.nodes_cache.values() + assert node_2 in self.cluster.nodes_manager.nodes_cache.values() + # validate changed slot is assigned to node 3 + assert self.cluster.nodes_manager.get_node_from_slot( + 123 + ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_3) + # validate the connections are in the correct state + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=0, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", NODE_PORT_1, 0, 122), + SlotsRange("0.0.0.0", NODE_PORT_3, 123, 123), + SlotsRange("0.0.0.0", NODE_PORT_1, 124, 2000), + SlotsRange("0.0.0.0", NODE_PORT_3, 2001, 3000), + SlotsRange("0.0.0.0", NODE_PORT_1, 3001, 5460), + SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 7000), + SlotsRange("0.0.0.0", NODE_PORT_3, 7001, 8000), + SlotsRange("0.0.0.0", NODE_PORT_2, 8001, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + # execute command with node 2 connection + self.cluster.set("anyprefix:{1}:k", "VAL") + + # validate old nodes are there + assert node_1 in self.cluster.nodes_manager.nodes_cache.values() + assert node_2 in self.cluster.nodes_manager.nodes_cache.values() + # validate slot changes are reflected + assert self.cluster.nodes_manager.get_node_from_slot( + 8000 + ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_3) + + # validate the connections are in the correct state + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=0, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, + changed_connections_count=0, + ), + ], + ) + + def test_smigrating_smigrated_on_two_nodes_with_node_replacements(self): + """Test receiving an OSS maintenance notification on two nodes with node replacement.""" + # warm up connection pools - create several connections in each pool + self._warm_up_connection_pools(self.cluster, created_connections_count=3) + + node_1 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_1) + node_2 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_2) + node_3 = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=NODE_PORT_3) + + smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <0-5460>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, changed_connections_count=0 + ), + ], + ) - # Try to send a push notification to the clients of given server node - # Server node is defined by its port with the local test environment - # The message should be in the format: - # >3\r\n$7\r\nmessage\r\n$3\r\nfoo\r\n$4\r\neeee\r - notification = RespTranslator.smigrating_to_resp( - "TEST_NOTIFICATION 12182 127.0.0.1:15380" + smigrating_node_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 13 <5461-10922>" ) - self.proxy_helper.send_notification(pubsub.connection.port, notification) - res = self.proxy_helper.get_connections() - print(res) + self.proxy_helper.send_notification(NODE_PORT_2, smigrating_node_2) - test_msg = pubsub.get_message(timeout=1) - print(test_msg) + # execute command with node 2 connection + self.cluster.set("anyprefix:{1}:k", "VAL") + + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ConnectionStateExpectation( + node_port=NODE_PORT_2, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 14 0.0.0.0:15382 <0-5460>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + + smigrated_node_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 15 0.0.0.0:15382 <5461-10922>" + ) + self.proxy_helper.send_notification(NODE_PORT_2, smigrated_node_2) + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", 15382, 0, 5460), + SlotsRange("0.0.0.0", NODE_PORT_2, 5461, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + + # validate node 1 is removed + assert node_1 not in self.cluster.nodes_manager.nodes_cache.values() + # validate node 2 is still there + assert node_2 in self.cluster.nodes_manager.nodes_cache.values() + # validate new node is added + new_node = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15382) + assert new_node is not None + assert new_node.redis_connection is not None + # validate a slot from the changed range is assigned to the new node + assert self.cluster.nodes_manager.get_node_from_slot( + 123 + ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15382) + + # validate the connections are in the correct state + self._validate_removed_node_connections(node_1) + + # validate the connections are in the correct state + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_2, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + + self.proxy_helper.set_cluster_slots( + "test_topology", + [ + SlotsRange("0.0.0.0", 15382, 0, 5460), + SlotsRange("0.0.0.0", 15383, 5461, 10922), + SlotsRange("0.0.0.0", NODE_PORT_3, 10923, 16383), + ], + ) + # execute command with node 2 connection + self.cluster.set("anyprefix:{1}:k", "VAL") + + # validate node 2 is removed + assert node_2 not in self.cluster.nodes_manager.nodes_cache.values() + # validate node 3 is still there + assert node_3 in self.cluster.nodes_manager.nodes_cache.values() + # validate new node is added + new_node = self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15383) + assert new_node is not None + assert new_node.redis_connection is not None + # validate a slot from the changed range is assigned to the new node + assert self.cluster.nodes_manager.get_node_from_slot( + 8000 + ) == self.cluster.nodes_manager.get_node(host="0.0.0.0", port=15383) + + # validate the connections in removed node are in the correct state + self._validate_removed_node_connections(node_2) + + def test_smigrating_smigrated_on_the_same_node_two_slot_ranges( + self, + ): + """ + Test receiving an OSS maintenance notification on the same node twice. + The focus here is to validate that the timeouts are not unrelaxed if a second + migration is in progress + """ + # warm up connection pools - create several connections in each pool + self._warm_up_connection_pools(self.cluster, created_connections_count=1) + + smigrating_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 12 <1000-2000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + + smigrating_node_1_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATING 13 <3000-4000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrating_node_1_2) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=1, + state=MaintenanceState.MAINTENANCE, + relaxed_timeout=self.config.relaxed_timeout, + ), + ], + ) + smigrated_node_1 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 14 0.0.0.0:15380 <1000-2000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + # this functionality is part of CAE-1038 and will be added later + # validate the timeout is still relaxed + # self._validate_connections_states( + # self.cluster, + # [ + # ConnectionStateExpectation( + # node_port=NODE_PORT_1, + # changed_connections_count=1, + # state=MaintenanceState.MAINTENANCE, + # relaxed_timeout=self.config.relaxed_timeout, + # ), + # ], + # ) + smigrated_node_1_2 = RespTranslator.oss_maint_notification_to_resp( + "SMIGRATED 15 0.0.0.0:15381 <3000-4000>" + ) + self.proxy_helper.send_notification(NODE_PORT_1, smigrated_node_1_2) + # execute command with node 1 connection + self.cluster.set("anyprefix:{3}:k", "VAL") + self._validate_connections_states( + self.cluster, + [ + ConnectionStateExpectation( + node_port=NODE_PORT_1, + changed_connections_count=0, + ), + ], + ) diff --git a/tests/maint_notifications/test_maint_notifications.py b/tests/maint_notifications/test_maint_notifications.py index 4ae5ae60a5..adb9ebb5ea 100644 --- a/tests/maint_notifications/test_maint_notifications.py +++ b/tests/maint_notifications/test_maint_notifications.py @@ -493,40 +493,44 @@ class TestOSSNodeMigratedNotification: def test_init_with_defaults(self): """Test OSSNodeMigratedNotification initialization with default values.""" with patch("time.monotonic", return_value=1000): - notification = OSSNodeMigratedNotification(id=1) + notification = OSSNodeMigratedNotification( + id=1, node_address="127.0.0.1:6380" + ) assert notification.id == 1 assert notification.ttl == OSSNodeMigratedNotification.DEFAULT_TTL assert notification.creation_time == 1000 - assert notification.node_address is None + assert notification.node_address == "127.0.0.1:6380" assert notification.slots is None def test_init_with_all_parameters(self): """Test OSSNodeMigratedNotification initialization with all parameters.""" with patch("time.monotonic", return_value=1000): slots = [1, 2, 3, 4, 5] + node_address = "127.0.0.1:6380" notification = OSSNodeMigratedNotification( id=1, - node_address="127.0.0.1:6380", + node_address=node_address, slots=slots, ) assert notification.id == 1 assert notification.ttl == OSSNodeMigratedNotification.DEFAULT_TTL assert notification.creation_time == 1000 - assert notification.node_address == "127.0.0.1:6380" + assert notification.node_address == node_address assert notification.slots == slots def test_default_ttl(self): """Test that DEFAULT_TTL is used correctly.""" assert OSSNodeMigratedNotification.DEFAULT_TTL == 30 - notification = OSSNodeMigratedNotification(id=1) + notification = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") assert notification.ttl == 30 def test_repr(self): """Test OSSNodeMigratedNotification string representation.""" with patch("time.monotonic", return_value=1000): + node_address = "127.0.0.1:6380" notification = OSSNodeMigratedNotification( id=1, - node_address="127.0.0.1:6380", + node_address=node_address, slots=[1, 2, 3], ) @@ -555,13 +559,13 @@ def test_equality_same_id_and_type(self): def test_equality_different_id(self): """Test inequality for notifications with different id.""" - notification1 = OSSNodeMigratedNotification(id=1) - notification2 = OSSNodeMigratedNotification(id=2) + notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") + notification2 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6380") assert notification1 != notification2 def test_equality_different_type(self): """Test inequality for notifications of different types.""" - notification1 = OSSNodeMigratedNotification(id=1) + notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") notification2 = NodeMigratedNotification(id=1) assert notification1 != notification2 @@ -582,16 +586,16 @@ def test_hash_same_id_and_type(self): def test_hash_different_id(self): """Test hash for notifications with different id.""" - notification1 = OSSNodeMigratedNotification(id=1) - notification2 = OSSNodeMigratedNotification(id=2) + notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") + notification2 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6380") assert hash(notification1) != hash(notification2) def test_in_set(self): """Test that notifications can be used in sets.""" - notification1 = OSSNodeMigratedNotification(id=1) - notification2 = OSSNodeMigratedNotification(id=1) - notification3 = OSSNodeMigratedNotification(id=2) - notification4 = OSSNodeMigratedNotification(id=2) + notification1 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") + notification2 = OSSNodeMigratedNotification(id=1, node_address="127.0.0.1:6380") + notification3 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6381") + notification4 = OSSNodeMigratedNotification(id=2, node_address="127.0.0.1:6381") notification_set = {notification1, notification2, notification3, notification4} assert (