Skip to content

Commit 2f03e0f

Browse files
authored
ssh client: implement password_auth_requested (#403)
1 parent 1b82e4e commit 2f03e0f

File tree

2 files changed

+55
-22
lines changed

2 files changed

+55
-22
lines changed

src/scmrepo/git/backend/dulwich/asyncssh_vendor.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ async def _read_all(read: Callable[[int], Coroutine], n: Optional[int] = None) -
4141
return b"".join(result)
4242

4343

44+
async def _getpass(*args, **kwargs) -> str:
45+
from getpass import getpass
46+
47+
return await asyncio.to_thread(getpass, *args, **kwargs)
48+
49+
4450
class _StderrWrapper:
4551
def __init__(self, stderr: "SSHReader", loop: asyncio.AbstractEventLoop) -> None:
4652
self.stderr = stderr
@@ -164,8 +170,6 @@ async def public_key_auth_requested( # noqa: C901
164170
return None
165171

166172
async def _read_private_key_interactive(self, path: "FilePath") -> "SSHKey":
167-
from getpass import getpass
168-
169173
from asyncssh.public_key import (
170174
KeyEncryptionError,
171175
KeyImportError,
@@ -177,11 +181,8 @@ async def _read_private_key_interactive(self, path: "FilePath") -> "SSHKey":
177181
if passphrase:
178182
return read_private_key(path, passphrase=passphrase)
179183

180-
loop = asyncio.get_running_loop()
181184
for _ in range(3):
182-
passphrase = await loop.run_in_executor(
183-
None, getpass, f"Enter passphrase for key '{path}': "
184-
)
185+
passphrase = await _getpass(f"Enter passphrase for key {path!r}: ")
185186
if passphrase:
186187
try:
187188
key = read_private_key(path, passphrase=passphrase)
@@ -201,23 +202,20 @@ async def kbdint_challenge_received( # pylint: disable=invalid-overridden-metho
201202
lang: str,
202203
prompts: "KbdIntPrompts",
203204
) -> Optional["KbdIntResponse"]:
204-
from getpass import getpass
205-
206205
if os.environ.get("GIT_TERMINAL_PROMPT") == "0":
207206
return None
208207

209-
def _getpass(prompt: str) -> str:
210-
return getpass(prompt=prompt).rstrip()
211-
212208
if instructions:
213209
pass
214-
loop = asyncio.get_running_loop()
215-
return [
216-
await loop.run_in_executor(
217-
None, _getpass, f"({name}) {prompt}" if name else prompt
218-
)
219-
for prompt, _ in prompts
220-
]
210+
211+
response: list[str] = []
212+
for prompt, _echo in prompts:
213+
p = await _getpass(f"({name}) {prompt}" if name else prompt)
214+
response.append(p.rstrip())
215+
return response
216+
217+
async def password_auth_requested(self) -> str:
218+
return await _getpass()
221219

222220

223221
class AsyncSSHVendor(BaseAsyncObject, SSHVendor):

tests/test_dulwich.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import asyncssh
99
import paramiko
1010
import pytest
11+
from paramiko.server import InteractiveQuery
1112
from pytest_mock import MockerFixture
1213
from pytest_test_utils.waiters import wait_until
1314

@@ -52,13 +53,24 @@ class Server(paramiko.ServerInterface):
5253
"""http://docs.paramiko.org/en/2.4/api/server.html."""
5354

5455
def __init__(self, commands, *args, **kwargs) -> None:
55-
super().__init__(*args, **kwargs)
56+
super().__init__()
5657
self.commands = commands
58+
self.allowed_auths = kwargs.get("allowed_auths", "publickey,password")
5759

5860
def check_channel_exec_request(self, channel, command):
5961
self.commands.append(command)
6062
return True
6163

64+
def check_auth_interactive(self, username: str, submethods: str):
65+
return InteractiveQuery(
66+
"Password", "Enter the password", f"Password for user {USER}:"
67+
)
68+
69+
def check_auth_interactive_response(self, responses):
70+
if responses[0] == PASSWORD:
71+
return paramiko.AUTH_SUCCESSFUL
72+
return paramiko.AUTH_FAILED
73+
6274
def check_auth_password(self, username, password):
6375
if username == USER and password == PASSWORD:
6476
return paramiko.AUTH_SUCCESSFUL
@@ -76,12 +88,12 @@ def check_channel_request(self, kind, chanid):
7688
return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
7789

7890
def get_allowed_auths(self, username):
79-
return "password,publickey"
91+
return self.allowed_auths
8092

8193

8294
@pytest.fixture
8395
def ssh_conn(request: pytest.FixtureRequest) -> dict[str, Any]:
84-
server = Server([])
96+
server = Server([], **getattr(request, "param", {}))
8597

8698
socket.setdefaulttimeout(10)
8799
request.addfinalizer(lambda: socket.setdefaulttimeout(None))
@@ -133,7 +145,8 @@ def test_run_command_password(server: Server, ssh_port: int):
133145
assert b"test_run_command_password" in server.commands
134146

135147

136-
def test_run_command_no_password(server: Server, ssh_port: int):
148+
@pytest.mark.parametrize("ssh_conn", [{"allowed_auths": "publickey"}], indirect=True)
149+
def test_run_command_no_password(ssh_port: int):
137150
vendor = AsyncSSHVendor()
138151
with pytest.raises(AuthError):
139152
vendor.run_command(
@@ -145,6 +158,28 @@ def test_run_command_no_password(server: Server, ssh_port: int):
145158
)
146159

147160

161+
@pytest.mark.parametrize(
162+
"ssh_conn",
163+
[{"allowed_auths": "password"}, {"allowed_auths": "keyboard-interactive"}],
164+
indirect=True,
165+
ids=["password", "interactive"],
166+
)
167+
def test_should_prompt_for_password_when_no_password_passed(
168+
mocker: MockerFixture, server: Server, ssh_port: int
169+
):
170+
mocked_getpass = mocker.patch("getpass.getpass", return_value=PASSWORD)
171+
vendor = AsyncSSHVendor()
172+
vendor.run_command(
173+
"127.0.0.1",
174+
"test_run_command_password",
175+
username=USER,
176+
port=ssh_port,
177+
password=None,
178+
)
179+
assert server.commands == [b"test_run_command_password"]
180+
mocked_getpass.asssert_called_once()
181+
182+
148183
def test_run_command_with_privkey(server: Server, ssh_port: int):
149184
key = asyncssh.import_private_key(CLIENT_KEY)
150185

0 commit comments

Comments
 (0)