8
8
import asyncssh
9
9
import paramiko
10
10
import pytest
11
+ from paramiko .server import InteractiveQuery
11
12
from pytest_mock import MockerFixture
12
13
from pytest_test_utils .waiters import wait_until
13
14
@@ -52,13 +53,24 @@ class Server(paramiko.ServerInterface):
52
53
"""http://docs.paramiko.org/en/2.4/api/server.html."""
53
54
54
55
def __init__ (self , commands , * args , ** kwargs ) -> None :
55
- super ().__init__ (* args , ** kwargs )
56
+ super ().__init__ ()
56
57
self .commands = commands
58
+ self .allowed_auths = kwargs .get ("allowed_auths" , "publickey,password" )
57
59
58
60
def check_channel_exec_request (self , channel , command ):
59
61
self .commands .append (command )
60
62
return True
61
63
64
+ def check_auth_interactive (self , username : str , submethods : str ):
65
+ return InteractiveQuery (
66
+ "Password" , "Enter the password" , f"Password for user { USER } :"
67
+ )
68
+
69
+ def check_auth_interactive_response (self , responses ):
70
+ if responses [0 ] == PASSWORD :
71
+ return paramiko .AUTH_SUCCESSFUL
72
+ return paramiko .AUTH_FAILED
73
+
62
74
def check_auth_password (self , username , password ):
63
75
if username == USER and password == PASSWORD :
64
76
return paramiko .AUTH_SUCCESSFUL
@@ -76,12 +88,12 @@ def check_channel_request(self, kind, chanid):
76
88
return paramiko .OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
77
89
78
90
def get_allowed_auths (self , username ):
79
- return "password,publickey"
91
+ return self . allowed_auths
80
92
81
93
82
94
@pytest .fixture
83
95
def ssh_conn (request : pytest .FixtureRequest ) -> dict [str , Any ]:
84
- server = Server ([])
96
+ server = Server ([], ** getattr ( request , "param" , {}) )
85
97
86
98
socket .setdefaulttimeout (10 )
87
99
request .addfinalizer (lambda : socket .setdefaulttimeout (None ))
@@ -133,7 +145,8 @@ def test_run_command_password(server: Server, ssh_port: int):
133
145
assert b"test_run_command_password" in server .commands
134
146
135
147
136
- def test_run_command_no_password (server : Server , ssh_port : int ):
148
+ @pytest .mark .parametrize ("ssh_conn" , [{"allowed_auths" : "publickey" }], indirect = True )
149
+ def test_run_command_no_password (ssh_port : int ):
137
150
vendor = AsyncSSHVendor ()
138
151
with pytest .raises (AuthError ):
139
152
vendor .run_command (
@@ -145,6 +158,28 @@ def test_run_command_no_password(server: Server, ssh_port: int):
145
158
)
146
159
147
160
161
+ @pytest .mark .parametrize (
162
+ "ssh_conn" ,
163
+ [{"allowed_auths" : "password" }, {"allowed_auths" : "keyboard-interactive" }],
164
+ indirect = True ,
165
+ ids = ["password" , "interactive" ],
166
+ )
167
+ def test_should_prompt_for_password_when_no_password_passed (
168
+ mocker : MockerFixture , server : Server , ssh_port : int
169
+ ):
170
+ mocked_getpass = mocker .patch ("getpass.getpass" , return_value = PASSWORD )
171
+ vendor = AsyncSSHVendor ()
172
+ vendor .run_command (
173
+ "127.0.0.1" ,
174
+ "test_run_command_password" ,
175
+ username = USER ,
176
+ port = ssh_port ,
177
+ password = None ,
178
+ )
179
+ assert server .commands == [b"test_run_command_password" ]
180
+ mocked_getpass .asssert_called_once ()
181
+
182
+
148
183
def test_run_command_with_privkey (server : Server , ssh_port : int ):
149
184
key = asyncssh .import_private_key (CLIENT_KEY )
150
185
0 commit comments