Skip to content

Commit d73cd35

Browse files
committed
Adding maintenance state to connections. Migrating and Migrated are not processed in in Moving state. Tests are updated
1 parent ce31ec7 commit d73cd35

File tree

5 files changed

+237
-50
lines changed

5 files changed

+237
-50
lines changed

redis/_parsers/hiredis.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def read_response(self, disable_decoding=False, push_request=False):
152152
disable_decoding=disable_decoding,
153153
push_request=push_request,
154154
)
155+
return response
155156

156157
if disable_decoding:
157158
response = self._reader.gets(False)

redis/connection.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
MaintenanceEventConnectionHandler,
4343
MaintenanceEventPoolHandler,
4444
MaintenanceEventsConfig,
45+
MaintenanceState,
4546
)
4647
from .retry import Retry
4748
from .utils import (
@@ -285,6 +286,7 @@ def __init__(
285286
maintenance_events_config: Optional[MaintenanceEventsConfig] = None,
286287
tmp_host_address: Optional[str] = None,
287288
tmp_relax_timeout: Optional[float] = -1,
289+
maintenance_state: "MaintenanceState" = MaintenanceState.NONE,
288290
):
289291
"""
290292
Initialize a new Connection.
@@ -374,6 +376,7 @@ def __init__(
374376
self._should_reconnect = False
375377
self.tmp_host_address = tmp_host_address
376378
self.tmp_relax_timeout = tmp_relax_timeout
379+
self.maintenance_state = maintenance_state
377380

378381
def __repr__(self):
379382
repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
@@ -835,6 +838,9 @@ def update_tmp_settings(
835838
if tmp_relax_timeout is not SENTINEL:
836839
self.tmp_relax_timeout = tmp_relax_timeout
837840

841+
def set_maintenance_state(self, state: "MaintenanceState"):
842+
self.maintenance_state = state
843+
838844

839845
class Connection(AbstractConnection):
840846
"Manages TCP communication to and from a Redis server"
@@ -1724,11 +1730,18 @@ def make_connection(self) -> "ConnectionInterface":
17241730
raise MaxConnectionsError("Too many connections")
17251731
self._created_connections += 1
17261732

1733+
# Pass current maintenance_state to new connections
1734+
maintenance_state = self.connection_kwargs.get(
1735+
"maintenance_state", MaintenanceState.NONE
1736+
)
1737+
kwargs = dict(self.connection_kwargs)
1738+
kwargs["maintenance_state"] = maintenance_state
1739+
17271740
if self.cache is not None:
17281741
return CacheProxyConnection(
1729-
self.connection_class(**self.connection_kwargs), self.cache, self._lock
1742+
self.connection_class(**kwargs), self.cache, self._lock
17301743
)
1731-
return self.connection_class(**self.connection_kwargs)
1744+
return self.connection_class(**kwargs)
17321745

17331746
def release(self, connection: "Connection") -> None:
17341747
"Releases the connection back to the pool"
@@ -1953,6 +1966,16 @@ async def _mock(self, error: RedisError):
19531966
"""
19541967
pass
19551968

1969+
def set_maintenance_state_for_all(self, state: "MaintenanceState"):
1970+
with self._lock:
1971+
for conn in self._available_connections:
1972+
conn.set_maintenance_state(state)
1973+
for conn in self._in_use_connections:
1974+
conn.set_maintenance_state(state)
1975+
1976+
def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"):
1977+
self.connection_kwargs["maintenance_state"] = state
1978+
19561979

19571980
class BlockingConnectionPool(ConnectionPool):
19581981
"""
@@ -2047,15 +2070,20 @@ def make_connection(self):
20472070
if self._in_maintenance:
20482071
self._lock.acquire()
20492072
self._locked = True
2073+
# Pass current maintenance_state to new connections
2074+
maintenance_state = self.connection_kwargs.get(
2075+
"maintenance_state", MaintenanceState.NONE
2076+
)
2077+
kwargs = dict(self.connection_kwargs)
2078+
kwargs["maintenance_state"] = maintenance_state
20502079
if self.cache is not None:
20512080
connection = CacheProxyConnection(
2052-
self.connection_class(**self.connection_kwargs),
2081+
self.connection_class(**kwargs),
20532082
self.cache,
20542083
self._lock,
20552084
)
20562085
else:
2057-
connection = self.connection_class(**self.connection_kwargs)
2058-
2086+
connection = self.connection_class(**kwargs)
20592087
self._connections.append(connection)
20602088
return connection
20612089
finally:
@@ -2266,3 +2294,12 @@ def _update_maintenance_events_configs_for_connections(
22662294
def set_in_maintenance(self, in_maintenance: bool):
22672295
"""Set the maintenance mode for the connection pool."""
22682296
self._in_maintenance = in_maintenance
2297+
2298+
def set_maintenance_state_for_all(self, state: "MaintenanceState"):
2299+
with self._lock:
2300+
for conn in getattr(self, "_connections", []):
2301+
if conn:
2302+
conn.set_maintenance_state(state)
2303+
2304+
def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"):
2305+
self.connection_kwargs["maintenance_state"] = state

redis/maintenance_events.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import logging
23
import threading
34
import time
@@ -6,6 +7,13 @@
67

78
from redis.typing import Number
89

10+
11+
class MaintenanceState(enum.Enum):
12+
NONE = "none"
13+
MOVING = "moving"
14+
MIGRATING = "migrating"
15+
16+
917
if TYPE_CHECKING:
1018
from redis.connection import (
1119
BlockingConnectionPool,
@@ -351,6 +359,9 @@ def handle_node_moving_event(self, event: NodeMovingEvent):
351359
):
352360
if getattr(self.pool, "set_in_maintenance", False):
353361
self.pool.set_in_maintenance(True)
362+
# Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance)
363+
self.pool.set_maintenance_state_for_all(MaintenanceState.MOVING)
364+
self.pool.set_maintenance_state_in_kwargs(MaintenanceState.MOVING)
354365
# edit the config for new connections until the notification expires
355366
self.pool.update_connection_kwargs_with_tmp_settings(
356367
tmp_host_address=event.new_node_host,
@@ -368,7 +379,6 @@ def handle_node_moving_event(self, event: NodeMovingEvent):
368379
tmp_host_address=event.new_node_host,
369380
tmp_relax_timeout=self.config.relax_timeout,
370381
)
371-
372382
# take care for the inactive connections in the pool
373383
# delete them and create new ones
374384
self.pool.disconnect_and_reconfigure_free_connections(
@@ -388,16 +398,19 @@ def handle_node_moved_event(self):
388398
tmp_host_address=None,
389399
tmp_relax_timeout=-1,
390400
)
401+
# Clear state to NONE in kwargs immediately after updating tmp kwargs
402+
self.pool.set_maintenance_state_in_kwargs(MaintenanceState.NONE)
391403
with self.pool._lock:
392404
if self.config.is_relax_timeouts_enabled():
393405
# reset the timeout for existing connections
394406
self.pool.update_connections_current_timeout(
395407
relax_timeout=-1, include_free_connections=True
396408
)
397-
398409
self.pool.update_connections_tmp_settings(
399410
tmp_host_address=None, tmp_relax_timeout=-1
400411
)
412+
# Clear state to NONE for all connections
413+
self.pool.set_maintenance_state_for_all(MaintenanceState.NONE)
401414

402415

403416
class MaintenanceEventConnectionHandler:
@@ -416,17 +429,24 @@ def handle_event(self, event: MaintenanceEvent):
416429
logging.error(f"Unhandled event type: {event}")
417430

418431
def handle_migrating_event(self, notification: NodeMigratingEvent):
419-
if not self.config.is_relax_timeouts_enabled():
432+
if (
433+
self.connection.maintenance_state == MaintenanceState.MOVING
434+
or not self.config.is_relax_timeouts_enabled()
435+
):
420436
return
421-
437+
self.connection.set_maintenance_state(MaintenanceState.MIGRATING)
422438
# extend the timeout for all created connections
423439
self.connection.update_current_socket_timeout(self.config.relax_timeout)
424440
self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout)
425441

426442
def handle_migration_completed_event(self, notification: "NodeMigratedEvent"):
427-
if not self.config.is_relax_timeouts_enabled():
443+
# Only reset timeouts if state is not MOVING and relax timeouts are enabled
444+
if (
445+
self.connection.maintenance_state == MaintenanceState.MOVING
446+
or not self.config.is_relax_timeouts_enabled()
447+
):
428448
return
429-
449+
self.connection.set_maintenance_state(MaintenanceState.NONE)
430450
# Node migration completed - reset the connection
431451
# timeouts by providing -1 as the relax timeout
432452
self.connection.update_current_socket_timeout(-1)

tests/test_connection_pool.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import redis
1010
from redis.cache import CacheConfig
1111
from redis.connection import CacheProxyConnection, Connection, to_bool
12+
from redis.maintenance_events import MaintenanceState
1213
from redis.utils import SSL_AVAILABLE
1314

1415
from .conftest import (
@@ -53,10 +54,15 @@ def get_pool(
5354
return pool
5455

5556
def test_connection_creation(self):
56-
connection_kwargs = {"foo": "bar", "biz": "baz"}
57+
connection_kwargs = {
58+
"foo": "bar",
59+
"biz": "baz",
60+
"maintenance_state": MaintenanceState.NONE,
61+
}
5762
pool = self.get_pool(
5863
connection_kwargs=connection_kwargs, connection_class=DummyConnection
5964
)
65+
6066
connection = pool.get_connection()
6167
assert isinstance(connection, DummyConnection)
6268
assert connection.kwargs == connection_kwargs
@@ -152,7 +158,9 @@ def test_connection_creation(self, master_host):
152158
"host": master_host[0],
153159
"port": master_host[1],
154160
}
161+
155162
pool = self.get_pool(connection_kwargs=connection_kwargs)
163+
connection_kwargs["maintenance_state"] = MaintenanceState.NONE
156164
connection = pool.get_connection()
157165
assert isinstance(connection, DummyConnection)
158166
assert connection.kwargs == connection_kwargs

0 commit comments

Comments
 (0)