Skip to content

Commit 6677c61

Browse files
authored
connectors.ssh: add workaround for paramiko no session error
See: paramiko/paramiko#1390
1 parent 6428eed commit 6677c61

File tree

2 files changed

+144
-0
lines changed

2 files changed

+144
-0
lines changed

src/pyinfra/connectors/ssh.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import click
1111
from paramiko import AuthenticationException, BadHostKeyException, SFTPClient, SSHException
12+
from paramiko.agent import Agent
1213
from typing_extensions import TypedDict, Unpack, override
1314

1415
from pyinfra import logger
@@ -286,10 +287,64 @@ def _connect(self) -> None:
286287
f"Host key for {e.hostname} does not match.",
287288
)
288289

290+
except SSHException as e:
291+
if self._retry_paramiko_agent_keys(hostname, kwargs, e):
292+
return
293+
raise
294+
289295
@override
290296
def disconnect(self) -> None:
291297
self.get_file_transfer_connection.cache.clear()
292298

299+
def _retry_paramiko_agent_keys(
300+
self,
301+
hostname: str,
302+
kwargs: dict[str, Any],
303+
error: SSHException,
304+
) -> bool:
305+
# Workaround for Paramiko multi-key bug (paramiko/paramiko#1390).
306+
if "no existing session" not in str(error).lower():
307+
return False
308+
309+
if not kwargs.get("allow_agent"):
310+
return False
311+
312+
try:
313+
agent_keys = list(Agent().get_keys())
314+
except Exception:
315+
return False
316+
317+
if not agent_keys:
318+
return False
319+
320+
# Skip the first agent key, since Paramiko already attempted it
321+
attempt_keys = agent_keys[1:] if len(agent_keys) > 1 else agent_keys
322+
323+
for agent_key in attempt_keys:
324+
if self.client is not None:
325+
try:
326+
self.client.close()
327+
except Exception:
328+
pass
329+
330+
self.client = SSHClient()
331+
332+
single_key_kwargs = dict(kwargs)
333+
single_key_kwargs["allow_agent"] = False
334+
single_key_kwargs["pkey"] = agent_key
335+
336+
try:
337+
self.client.connect(hostname, **single_key_kwargs)
338+
return True
339+
except AuthenticationException:
340+
continue
341+
except SSHException as retry_error:
342+
if "no existing session" in str(retry_error).lower():
343+
continue
344+
raise retry_error
345+
346+
return False
347+
293348
@override
294349
def run_shell_command(
295350
self,

tests/test_connectors/test_ssh.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pyinfra.api.connect import connect_all
1111
from pyinfra.api.exceptions import ConnectError, PyinfraError
1212
from pyinfra.context import ctx_state
13+
from pyinfra.connectors import ssh
1314

1415
from ..util import make_inventory
1516

@@ -119,6 +120,94 @@ def test_connect_with_rsa_ssh_key(self):
119120

120121
connect_all(second_state)
121122

123+
def test_retry_paramiko_agent_keys_single_key(self):
124+
connector = ssh.SSHConnector.__new__(ssh.SSHConnector)
125+
connector.client = mock.Mock()
126+
127+
attempts = []
128+
connect_outcomes = [None]
129+
130+
def make_client():
131+
client = mock.Mock()
132+
133+
def fake_connect(hostname, **kwargs):
134+
attempts.append(dict(kwargs))
135+
outcome = connect_outcomes.pop(0)
136+
if isinstance(outcome, Exception):
137+
raise outcome
138+
139+
client.connect.side_effect = fake_connect
140+
client.close = mock.Mock()
141+
return client
142+
143+
with (
144+
mock.patch("pyinfra.connectors.ssh.Agent") as fake_agent,
145+
mock.patch("pyinfra.connectors.ssh.SSHClient", side_effect=make_client),
146+
):
147+
fake_agent.return_value.get_keys.return_value = ["key-one"]
148+
149+
result = connector._retry_paramiko_agent_keys(
150+
"host",
151+
{"allow_agent": True},
152+
SSHException("No existing session"),
153+
)
154+
155+
self.assertTrue(result)
156+
self.assertEqual(
157+
attempts,
158+
[
159+
{"allow_agent": False, "pkey": "key-one"},
160+
],
161+
)
162+
163+
def test_retry_paramiko_agent_keys_returns_false_without_keys(self):
164+
connector = ssh.SSHConnector.__new__(ssh.SSHConnector)
165+
connector.client = mock.Mock()
166+
167+
with mock.patch("pyinfra.connectors.ssh.Agent") as fake_agent:
168+
fake_agent.return_value.get_keys.return_value = []
169+
170+
result = connector._retry_paramiko_agent_keys(
171+
"host",
172+
{"allow_agent": True},
173+
SSHException("No existing session"),
174+
)
175+
176+
self.assertFalse(result)
177+
178+
@mock.patch("pyinfra.connectors.ssh.Agent")
179+
def test_connect_retries_agent_keys_after_paramiko_failure(self, fake_agent):
180+
key_one = mock.Mock(name="agent-key-1")
181+
key_two = mock.Mock(name="agent-key-2")
182+
fake_agent.return_value.get_keys.return_value = [key_one, key_two]
183+
184+
connect_calls = []
185+
186+
def fake_connect(hostname, **kwargs):
187+
connect_calls.append((hostname, dict(kwargs)))
188+
if len(connect_calls) == 1:
189+
raise SSHException("No existing session")
190+
191+
self.fake_connect_mock.side_effect = fake_connect
192+
193+
inventory = make_inventory(hosts=("somehost",))
194+
state = State(inventory, Config())
195+
196+
connect_all(state)
197+
198+
self.assertEqual(len(state.active_hosts), 1)
199+
self.assertEqual(len(connect_calls), 2)
200+
201+
first_hostname, first_kwargs = connect_calls[0]
202+
self.assertEqual(first_hostname, "somehost")
203+
self.assertTrue(first_kwargs.get("allow_agent"))
204+
self.assertNotIn("pkey", first_kwargs)
205+
206+
second_hostname, second_kwargs = connect_calls[1]
207+
self.assertEqual(second_hostname, "somehost")
208+
self.assertFalse(second_kwargs.get("allow_agent"))
209+
self.assertIs(second_kwargs.get("pkey"), key_two)
210+
122211
def test_connect_with_rsa_ssh_key_password(self):
123212
state = State(
124213
make_inventory(

0 commit comments

Comments
 (0)