@@ -54,33 +54,28 @@ def __init__(
54
54
):
55
55
self .host = host
56
56
self .user = user
57
- self .ssh_password = ssh_password if key_path is None else None
57
+ self .ssh_password = ssh_password if not key_path else None
58
58
self .port = port
59
59
self .nb_bytes = 1024
60
- self .keys , self .transport = [], None
60
+ self .keys = []
61
+ self .transport = None
61
62
key_type = key_type .lower ()
62
63
63
64
if key_path :
64
- self .keys .append (
65
- _KEY_TYPES [key_type ].from_private_key (
66
- open (path .expanduser (key_path ), 'r' ),
67
- key_password ,
68
- )
69
- )
65
+ key_file = open (path .expanduser (key_path ), 'r' )
66
+ key = _KEY_TYPES [key_type ].from_private_key (key_file , key_password )
67
+ self .keys .append (key )
70
68
elif ssh_password is None :
71
69
self .keys = paramiko .Agent ().get_keys ()
72
-
73
70
try :
74
- key_file = _KEY_TYPES [key_type ].from_private_key (
75
- open (path .expanduser (f"~/.ssh/id_{ key_type } " ), 'r' ),
76
- key_password
77
- )
71
+ key_file = open (path .expanduser (f"~/.ssh/id_{ key_type } " ), 'r' )
72
+ key = _KEY_TYPES [key_type ].from_private_key (
73
+ key_file , key_password )
78
74
except Exception :
79
75
pass
80
76
else :
81
- self .keys .insert (
82
- len (self .keys ) if key_password is None else 0 , key_file
83
- )
77
+ index = len (self .keys ) if key_password is None else 0
78
+ self .keys .insert (index , key )
84
79
85
80
if not self .keys :
86
81
logging .error ("No valid key found" )
@@ -96,10 +91,8 @@ def connect(self):
96
91
97
92
if self .ssh_password is not None :
98
93
try :
99
- self .transport .connect (
100
- username = self .user ,
101
- password = self .ssh_password ,
102
- )
94
+ self .transport .connect (username = self .user ,
95
+ password = self .ssh_password )
103
96
except paramiko .SSHException :
104
97
pass
105
98
else :
@@ -117,27 +110,24 @@ def connect(self):
117
110
logging .info (f"Successfully connected to { self .user } @{ self .host } " )
118
111
return 0
119
112
120
- def __run_until_event (
113
+ def _run_until_event (
121
114
self ,
122
115
command ,
123
116
stop_event ,
124
117
display = True ,
125
- capture_output = False ,
118
+ capture = False ,
126
119
shell = True ,
127
120
combine_stderr = False ,
128
121
):
122
+ exit_code , output = 0 , ""
129
123
channel = self .transport .open_session ()
130
- output = ""
131
-
132
124
channel .settimeout (2 )
133
125
channel .set_combine_stderr (combine_stderr )
134
-
135
126
if shell :
136
127
channel .get_pty ()
137
-
138
128
channel .exec_command (command )
139
129
140
- if not display and not capture_output :
130
+ if not display and not capture :
141
131
stop_event .wait ()
142
132
else :
143
133
while True :
@@ -148,102 +138,92 @@ def __run_until_event(
148
138
break
149
139
continue
150
140
151
- if not len ( raw_data ) :
141
+ if not raw_data :
152
142
break
153
-
154
143
data = raw_data .decode ("utf-8" )
155
-
156
144
if display :
157
145
print (data , end = '' )
158
-
159
- if capture_output :
146
+ if capture :
160
147
output += data
161
-
162
148
if stop_event .is_set ():
163
149
break
164
150
165
151
channel .close ()
166
152
167
- if not channel .exit_status_ready ():
168
- return ( 0 , output . splitlines () )
153
+ if channel .exit_status_ready ():
154
+ exit_code = channel . recv_exit_status ( )
169
155
170
- return (channel . recv_exit_status () , output .splitlines ())
156
+ return (exit_code , output .splitlines ())
171
157
172
- def __run_until_exit (
158
+ def _run_until_exit (
173
159
self ,
174
160
command ,
175
161
timeout ,
176
162
display = True ,
177
- capture_output = False ,
163
+ capture = False ,
178
164
shell = True ,
179
165
combine_stderr = False ,
180
166
):
167
+ exit_code , output = 0 , ""
181
168
channel = self .transport .open_session ()
182
- output = ""
183
-
184
169
channel .settimeout (timeout )
185
170
channel .set_combine_stderr (combine_stderr )
186
-
187
171
if shell :
188
172
channel .get_pty ()
189
-
190
173
channel .exec_command (command )
191
174
192
175
try :
193
- if not display and not capture_output :
176
+ if not display and not capture :
194
177
return (channel .recv_exit_status (), output .splitlines ())
195
178
else :
196
179
while True :
197
180
raw_data = channel .recv (self .nb_bytes )
198
-
199
- if not len (raw_data ):
181
+ if not raw_data :
200
182
break
201
-
202
183
data = raw_data .decode ("utf-8" )
203
-
204
184
if display :
205
185
print (data , end = '' )
206
-
207
- if capture_output :
186
+ if capture :
208
187
output += data
209
188
except socket .timeout :
210
189
logging .warning (f"Timeout after { timeout } s" )
211
- return ( 1 , output . splitlines ())
190
+ exit_code = 1
212
191
except KeyboardInterrupt :
213
192
logging .info ("KeyboardInterrupt" )
214
- return (0 , output .splitlines ())
193
+ exit_code = 0
194
+ else :
195
+ exit_code = channel .recv_exit_status ()
215
196
finally :
216
197
channel .close ()
217
-
218
- return (channel .recv_exit_status (), output .splitlines ())
198
+ return (exit_code , output .splitlines ())
219
199
220
200
def run (
221
201
self ,
222
202
command ,
223
203
display = False ,
224
- capture_output = False ,
204
+ capture = False ,
225
205
shell = True ,
226
206
combine_stderr = False ,
227
207
timeout = None ,
228
208
stop_event = None ,
229
209
):
230
- if stop_event :
231
- return self .__run_until_event (
210
+ if stop_event is not None :
211
+ return self ._run_until_event (
232
212
command ,
233
213
stop_event ,
234
214
display = display ,
235
215
shell = shell ,
236
216
combine_stderr = combine_stderr ,
237
- capture_output = capture_output ,
217
+ capture = capture ,
238
218
)
239
219
else :
240
- return self .__run_until_exit (
220
+ return self ._run_until_exit (
241
221
command ,
242
222
timeout ,
243
223
display = display ,
244
224
shell = shell ,
245
225
combine_stderr = combine_stderr ,
246
- capture_output = capture_output ,
226
+ capture = capture ,
247
227
)
248
228
249
229
def disconnect (self ):
@@ -254,12 +234,11 @@ def __getattr__(self, target):
254
234
def wrapper (* args , ** kwargs ):
255
235
if not self .transport .is_authenticated ():
256
236
logging .error ("SSH session is not ready" )
257
- return 1
237
+ return
258
238
259
239
sftp_channel = SFTPController .from_transport (self .transport )
260
240
r = getattr (sftp_channel , target )(* args , ** kwargs )
261
241
sftp_channel .close ()
262
-
263
242
return r
264
243
265
244
return wrapper
0 commit comments