2 changes: 1 addition & 1 deletion pymongo/asynchronous/mongo_client.py
@@ -2825,7 +2825,7 @@ async def run(self) -> T:
if self._last_error is None:
self._last_error = exc

if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
if self._server is not None:
self._deprioritized_servers.append(self._server)

def _is_not_eligible_for_retry(self) -> bool:
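The hunk above replaces the sharded-only guard with a check that a server was actually selected for the failed attempt, so the failed server is deprioritized for the retry on any topology, not just behind mongos. A minimal standalone sketch of the pattern this enables follows; the helper names (pick_server, run_with_retry) are hypothetical and are not pymongo's real retry machinery.

# Illustration only: a sketch of the retry/deprioritization pattern enabled by
# the change above, using hypothetical helpers rather than pymongo internals.
import random
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def pick_server(servers: list[str], deprioritized: list[str]) -> str:
    """Prefer servers that have not failed yet; fall back to the full list."""
    candidates = [s for s in servers if s not in deprioritized] or servers
    return random.choice(candidates)


def run_with_retry(servers: list[str], op: Callable[[str], T], attempts: int = 2) -> T:
    deprioritized: list[str] = []
    last_error: Optional[Exception] = None
    for _ in range(attempts):
        server = pick_server(servers, deprioritized)
        try:
            return op(server)
        except Exception as exc:  # noqa: BLE001 - illustration only
            last_error = exc
            # With the mongos-only guard removed, the failed server is
            # deprioritized on replica sets as well as sharded clusters.
            deprioritized.append(server)
    assert last_error is not None
    raise last_error

In pymongo itself the accumulated list is passed to _select_server(..., deprioritized_servers=...), as the topology.py hunks below show.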
37 changes: 20 additions & 17 deletions pymongo/asynchronous/topology.py
@@ -265,6 +265,7 @@ async def select_servers(
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
operation_id: Optional[int] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> list[Server]:
"""Return a list of Servers matching selector, or time out.

@@ -292,7 +293,12 @@ async def select_servers(

async with self._lock:
server_descriptions = await self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
selector,
server_timeout,
operation,
operation_id,
address,
deprioritized_servers=deprioritized_servers,
)

return [
@@ -306,6 +312,7 @@ async def _select_servers_loop(
operation: str,
operation_id: Optional[int],
address: Optional[_Address],
deprioritized_servers: Optional[list[Server]] = None,
) -> list[ServerDescription]:
"""select_servers() guts. Hold the lock when calling this."""
now = time.monotonic()
@@ -324,7 +331,12 @@ async def _select_servers_loop(
)

server_descriptions = self._description.apply_selector(
selector, address, custom_selector=self._settings.server_selector
selector,
address,
custom_selector=self._settings.server_selector,
deprioritized_servers=[server.description for server in deprioritized_servers]
if deprioritized_servers
else None,
)

while not server_descriptions:
@@ -385,9 +397,13 @@ async def _select_server(
operation_id: Optional[int] = None,
) -> Server:
servers = await self.select_servers(
selector, operation, server_selection_timeout, address, operation_id
selector,
operation,
server_selection_timeout,
address,
operation_id,
deprioritized_servers,
)
servers = _filter_servers(servers, deprioritized_servers)
if len(servers) == 1:
return servers[0]
server1, server2 = random.sample(servers, 2)
@@ -1112,16 +1128,3 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDescription) -> bool:
if current_tv["processId"] != new_tv["processId"]:
return False
return current_tv["counter"] > new_tv["counter"]


def _filter_servers(
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
) -> list[Server]:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
return candidates

filtered = [server for server in candidates if server not in deprioritized_servers]

# If not possible to pick a prioritized server, return the original list
return filtered or candidates
6 changes: 3 additions & 3 deletions pymongo/server_selectors.py
@@ -34,16 +34,16 @@ class Selection:

@classmethod
def from_topology_description(cls, topology_description: TopologyDescription) -> Selection:
known_servers = topology_description.known_servers
candidate_servers = topology_description.candidate_servers
primary = None
for sd in known_servers:
for sd in candidate_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
primary = sd
break

return Selection(
topology_description,
topology_description.known_servers,
topology_description.candidate_servers,
topology_description.common_wire_version,
primary,
)
2 changes: 1 addition & 1 deletion pymongo/synchronous/mongo_client.py
@@ -2815,7 +2815,7 @@ def run(self) -> T:
if self._last_error is None:
self._last_error = exc

if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
if self._server is not None:
self._deprioritized_servers.append(self._server)

def _is_not_eligible_for_retry(self) -> bool:
Expand Down
37 changes: 20 additions & 17 deletions pymongo/synchronous/topology.py
@@ -265,6 +265,7 @@ def select_servers(
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
operation_id: Optional[int] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> list[Server]:
"""Return a list of Servers matching selector, or time out.

@@ -292,7 +293,12 @@ def select_servers(

with self._lock:
server_descriptions = self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
selector,
server_timeout,
operation,
operation_id,
address,
deprioritized_servers=deprioritized_servers,
)

return [
@@ -306,6 +312,7 @@ def _select_servers_loop(
operation: str,
operation_id: Optional[int],
address: Optional[_Address],
deprioritized_servers: Optional[list[Server]] = None,
) -> list[ServerDescription]:
"""select_servers() guts. Hold the lock when calling this."""
now = time.monotonic()
@@ -324,7 +331,12 @@ def _select_servers_loop(
)

server_descriptions = self._description.apply_selector(
selector, address, custom_selector=self._settings.server_selector
selector,
address,
custom_selector=self._settings.server_selector,
deprioritized_servers=[server.description for server in deprioritized_servers]
if deprioritized_servers
else None,
)

while not server_descriptions:
@@ -385,9 +397,13 @@ def _select_server(
operation_id: Optional[int] = None,
) -> Server:
servers = self.select_servers(
selector, operation, server_selection_timeout, address, operation_id
selector,
operation,
server_selection_timeout,
address,
operation_id,
deprioritized_servers,
)
servers = _filter_servers(servers, deprioritized_servers)
if len(servers) == 1:
return servers[0]
server1, server2 = random.sample(servers, 2)
@@ -1110,16 +1126,3 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDescription) -> bool:
if current_tv["processId"] != new_tv["processId"]:
return False
return current_tv["counter"] > new_tv["counter"]


def _filter_servers(
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
) -> list[Server]:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
return candidates

filtered = [server for server in candidates if server not in deprioritized_servers]

# If not possible to pick a prioritized server, return the original list
return filtered or candidates
22 changes: 21 additions & 1 deletion pymongo/topology_description.py
@@ -85,6 +85,7 @@ def __init__(
self._server_descriptions = server_descriptions
self._max_set_version = max_set_version
self._max_election_id = max_election_id
self._candidate_servers = list(self._server_descriptions.values())

# The heartbeat_frequency is used in staleness estimates.
self._topology_settings = topology_settings
@@ -248,6 +249,11 @@ def readable_servers(self) -> list[ServerDescription]:
"""List of readable Servers."""
return [s for s in self._server_descriptions.values() if s.is_readable]

@property
def candidate_servers(self) -> list[ServerDescription]:
"""List of Servers excluding deprioritized servers."""
return self._candidate_servers

@property
def common_wire_version(self) -> Optional[int]:
"""Minimum of all servers' max wire versions, or None."""
@@ -283,11 +289,24 @@ def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]:
if (cast(float, s.round_trip_time) - fastest) <= threshold
]

def _filter_servers(
self, deprioritized_servers: Optional[list[ServerDescription]] = None
) -> None:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
self._candidate_servers = self.known_servers
else:
filtered = [
server for server in self.known_servers if server not in deprioritized_servers
]
self._candidate_servers = filtered or self.known_servers

def apply_selector(
self,
selector: Any,
address: Optional[_Address] = None,
custom_selector: Optional[_ServerSelector] = None,
deprioritized_servers: Optional[list[ServerDescription]] = None,
) -> list[ServerDescription]:
"""List of servers matching the provided selector(s).

@@ -324,9 +343,10 @@ def apply_selector(
description = self.server_descriptions().get(address)
return [description] if description and description.is_server_type_known else []

self._filter_servers(deprioritized_servers)
# Primary selection fast path.
if self.topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary and type(selector) is Primary:
for sd in self._server_descriptions.values():
for sd in self._candidate_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
sds = [sd]
if custom_selector:
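Taken together with the earlier hunks, topology_description.py now owns the deprioritization filter that was removed from topology.py: apply_selector() records the candidate set via _filter_servers() and falls back to every known server when filtering would leave nothing to select. A small sketch of that fallback, using plain strings in place of ServerDescription instances (illustrative only, not the library API):

# Illustration only: the fallback behaviour the new _filter_servers() implements.
from typing import Optional


def candidate_servers(known: list[str], deprioritized: Optional[list[str]] = None) -> list[str]:
    """Drop deprioritized servers, but never leave selection with an empty pool."""
    if not deprioritized:
        return known
    filtered = [s for s in known if s not in deprioritized]
    return filtered or known  # if everything is deprioritized, fall back to all


assert candidate_servers(["rs0", "rs1", "rs2"], ["rs0"]) == ["rs1", "rs2"]
assert candidate_servers(["rs0"], ["rs0"]) == ["rs0"]  # fallback path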
71 changes: 71 additions & 0 deletions test/asynchronous/test_retryable_reads.py
@@ -21,6 +21,7 @@
import threading
from test.asynchronous.utils import async_set_fail_point

from pymongo import ReadPreference
from pymongo.errors import OperationFailure

sys.path[0:0] = [""]
@@ -182,6 +183,44 @@ async def test_retryable_reads_are_retried_on_a_different_mongos_when_one_is_available(
# Assert that both events occurred on different mongos.
assert listener.failed_events[0].connection_id != listener.failed_events[1].connection_id

@async_client_context.require_replica_set
@async_client_context.require_failCommand_fail_point
async def test_retryable_reads_are_retried_on_a_different_replica_when_one_is_available(self):
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 6},
}

replica_clients = []

for node in async_client_context.nodes:
client = await self.async_rs_or_single_client(*node, directConnection=True)
await async_set_fail_point(client, fail_command)
replica_clients.append(client)

listener = OvertCommandListener()
client = await self.async_rs_or_single_client(
event_listeners=[listener],
retryReads=True,
directConnection=False,
readPreference="secondaryPreferred",
)

with self.assertRaises(OperationFailure):
await client.t.t.find_one({})

# Disable failpoints on each node
for client in replica_clients:
fail_command["mode"] = "off"
await async_set_fail_point(client, fail_command)

self.assertEqual(len(listener.failed_events), 2)
self.assertEqual(len(listener.succeeded_events), 0)

# Assert that both events occurred on different nodes.
assert listener.failed_events[0].connection_id != listener.failed_events[1].connection_id

@async_client_context.require_multiple_mongoses
@async_client_context.require_failCommand_fail_point
async def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are_available(
@@ -218,6 +257,38 @@ async def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are_available(
# Assert that both events occurred on the same mongos.
assert listener.succeeded_events[0].connection_id == listener.failed_events[0].connection_id

@async_client_context.require_replica_set
@async_client_context.require_failCommand_fail_point
async def test_retryable_reads_are_retried_on_the_same_replica_when_no_others_are_available(
self
):
fail_command = {
"configureFailPoint": "failCommand",
"mode": {"times": 1},
"data": {"failCommands": ["find"], "errorCode": 6},
}

node_client = await self.async_rs_or_single_client(*list(async_client_context.nodes)[0])
await async_set_fail_point(node_client, fail_command)

listener = OvertCommandListener()
client = await self.async_rs_or_single_client(
event_listeners=[listener],
retryReads=True,
)

await client.t.t.find_one({})

# Disable failpoints
fail_command["mode"] = "off"
await async_set_fail_point(node_client, fail_command)

self.assertEqual(len(listener.failed_events), 1)
self.assertEqual(len(listener.succeeded_events), 1)

# Assert that both events occurred on the same node.
assert listener.succeeded_events[0].connection_id == listener.failed_events[0].connection_id

@async_client_context.require_failCommand_fail_point
async def test_retryable_reads_are_retried_on_the_same_implicit_session(self):
listener = OvertCommandListener()