Skip to content

Commit f3285fd

Browse files
Olivier RoquesOlivier Roques
authored andcommitted
Format and clean code
1 parent b9b39c8 commit f3285fd

File tree

4 files changed

+60
-81
lines changed

4 files changed

+60
-81
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ ssh_controller.connect()
4343
```python
4444
return_code, output = ssh_controller.run(
4545
command="echo 'Hello world!' > /tmp/hello.txt",
46-
display=True, # display output, false by default
47-
capture_output=True, # save output, false by default
46+
display=True, # display output, false by default
47+
capture=True, # save output, false by default
4848
# request a shell to run the command, true by default
4949
shell=True,
5050
# combine stderr into stdout when shell=False, false by default
@@ -127,7 +127,7 @@ kwargs_sleep = {
127127
kwargs_ping = {
128128
"command": "echo 'thread ping: starting ping' && ping localhost",
129129
"display": True,
130-
"capture_output": True,
130+
"capture": True,
131131
"stop_event": stop_event_ping,
132132
}
133133

examples/demo.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,18 @@ def demo_key():
1616
ssh_controller = sshcontroller.SSHController(
1717
host=HOST_IP,
1818
user="olivier",
19-
key_path="~/.ssh/id_rsa", # if omitted, look for keys in SSH agent and in ~/.ssh/
20-
key_password=KEY_PWD, # optional
21-
key_type="rsa", # rsa (default), dsa, ecdsa or ed25519
22-
port=22, # 22 is the default
19+
key_path="~/.ssh/id_rsa", # if omitted, look in agent and in ~/.ssh/
20+
key_password=KEY_PWD, # optional
21+
key_type="rsa", # rsa (default), dsa, ecdsa or ed25519
22+
port=22, # 22 is the default
2323
)
2424

2525
ssh_controller.connect()
2626

2727
return_code, output = ssh_controller.run(
2828
command="echo 'Hello world!' > /tmp/hello.txt",
29-
display=True, # display output, false by default
30-
capture_output=True, # save output, false by default
29+
display=True, # display output, false by default
30+
capture=True, # save output, false by default
3131
# request a shell to run the command, true by default
3232
shell=True,
3333
# combine stderr into stdout when shell=False, false by default
@@ -50,11 +50,9 @@ def demo_key():
5050

5151

5252
def demo_pwd():
53-
ssh_controller = sshcontroller.SSHController(
54-
host=HOST_IP,
55-
user="olivier",
56-
ssh_password=SSH_PWD
57-
)
53+
ssh_controller = sshcontroller.SSHController(host=HOST_IP,
54+
user="olivier",
55+
ssh_password=SSH_PWD)
5856
ssh_controller.connect()
5957

6058
output = queue.Queue() # a queue to store the ping command output
@@ -69,18 +67,20 @@ def demo_pwd():
6967
kwargs_ping = {
7068
"command": "echo 'thread ping: starting ping' && ping localhost",
7169
"display": True,
72-
"capture_output": True,
70+
"capture": True,
7371
"stop_event": stop_event_ping,
7472
}
7573

7674
# call run() and store the command output in the queue
7775
def wrapper(kwargs):
7876
return output.put(ssh_controller.run(**kwargs))
7977

80-
thread_sleep = threading.Thread(
81-
target=ssh_controller.run, name="thread_sleep", kwargs=kwargs_sleep)
82-
thread_ping = threading.Thread(
83-
target=wrapper, name="thread_ping", args=(kwargs_ping, ))
78+
thread_sleep = threading.Thread(target=ssh_controller.run,
79+
name="thread_sleep",
80+
kwargs=kwargs_sleep)
81+
thread_ping = threading.Thread(target=wrapper,
82+
name="thread_ping",
83+
args=(kwargs_ping, ))
8484

8585
thread_ping.start()
8686
thread_sleep.start()

sshcontroller/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from sshcontroller.sshcontroller import SSHController, SFTPController
1+
from .sshcontroller import SSHController, SFTPController

sshcontroller/sshcontroller.py

Lines changed: 40 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -54,33 +54,28 @@ def __init__(
5454
):
5555
self.host = host
5656
self.user = user
57-
self.ssh_password = ssh_password if key_path is None else None
57+
self.ssh_password = ssh_password if not key_path else None
5858
self.port = port
5959
self.nb_bytes = 1024
60-
self.keys, self.transport = [], None
60+
self.keys = []
61+
self.transport = None
6162
key_type = key_type.lower()
6263

6364
if key_path:
64-
self.keys.append(
65-
_KEY_TYPES[key_type].from_private_key(
66-
open(path.expanduser(key_path), 'r'),
67-
key_password,
68-
)
69-
)
65+
key_file = open(path.expanduser(key_path), 'r')
66+
key = _KEY_TYPES[key_type].from_private_key(key_file, key_password)
67+
self.keys.append(key)
7068
elif ssh_password is None:
7169
self.keys = paramiko.Agent().get_keys()
72-
7370
try:
74-
key_file = _KEY_TYPES[key_type].from_private_key(
75-
open(path.expanduser(f"~/.ssh/id_{key_type}"), 'r'),
76-
key_password
77-
)
71+
key_file = open(path.expanduser(f"~/.ssh/id_{key_type}"), 'r')
72+
key = _KEY_TYPES[key_type].from_private_key(
73+
key_file, key_password)
7874
except Exception:
7975
pass
8076
else:
81-
self.keys.insert(
82-
len(self.keys) if key_password is None else 0, key_file
83-
)
77+
index = len(self.keys) if key_password is None else 0
78+
self.keys.insert(index, key)
8479

8580
if not self.keys:
8681
logging.error("No valid key found")
@@ -96,10 +91,8 @@ def connect(self):
9691

9792
if self.ssh_password is not None:
9893
try:
99-
self.transport.connect(
100-
username=self.user,
101-
password=self.ssh_password,
102-
)
94+
self.transport.connect(username=self.user,
95+
password=self.ssh_password)
10396
except paramiko.SSHException:
10497
pass
10598
else:
@@ -117,27 +110,24 @@ def connect(self):
117110
logging.info(f"Successfully connected to {self.user}@{self.host}")
118111
return 0
119112

120-
def __run_until_event(
113+
def _run_until_event(
121114
self,
122115
command,
123116
stop_event,
124117
display=True,
125-
capture_output=False,
118+
capture=False,
126119
shell=True,
127120
combine_stderr=False,
128121
):
122+
exit_code, output = 0, ""
129123
channel = self.transport.open_session()
130-
output = ""
131-
132124
channel.settimeout(2)
133125
channel.set_combine_stderr(combine_stderr)
134-
135126
if shell:
136127
channel.get_pty()
137-
138128
channel.exec_command(command)
139129

140-
if not display and not capture_output:
130+
if not display and not capture:
141131
stop_event.wait()
142132
else:
143133
while True:
@@ -148,102 +138,92 @@ def __run_until_event(
148138
break
149139
continue
150140

151-
if not len(raw_data):
141+
if not raw_data:
152142
break
153-
154143
data = raw_data.decode("utf-8")
155-
156144
if display:
157145
print(data, end='')
158-
159-
if capture_output:
146+
if capture:
160147
output += data
161-
162148
if stop_event.is_set():
163149
break
164150

165151
channel.close()
166152

167-
if not channel.exit_status_ready():
168-
return (0, output.splitlines())
153+
if channel.exit_status_ready():
154+
exit_code = channel.recv_exit_status()
169155

170-
return (channel.recv_exit_status(), output.splitlines())
156+
return (exit_code, output.splitlines())
171157

172-
def __run_until_exit(
158+
def _run_until_exit(
173159
self,
174160
command,
175161
timeout,
176162
display=True,
177-
capture_output=False,
163+
capture=False,
178164
shell=True,
179165
combine_stderr=False,
180166
):
167+
exit_code, output = 0, ""
181168
channel = self.transport.open_session()
182-
output = ""
183-
184169
channel.settimeout(timeout)
185170
channel.set_combine_stderr(combine_stderr)
186-
187171
if shell:
188172
channel.get_pty()
189-
190173
channel.exec_command(command)
191174

192175
try:
193-
if not display and not capture_output:
176+
if not display and not capture:
194177
return (channel.recv_exit_status(), output.splitlines())
195178
else:
196179
while True:
197180
raw_data = channel.recv(self.nb_bytes)
198-
199-
if not len(raw_data):
181+
if not raw_data:
200182
break
201-
202183
data = raw_data.decode("utf-8")
203-
204184
if display:
205185
print(data, end='')
206-
207-
if capture_output:
186+
if capture:
208187
output += data
209188
except socket.timeout:
210189
logging.warning(f"Timeout after {timeout}s")
211-
return (1, output.splitlines())
190+
exit_code = 1
212191
except KeyboardInterrupt:
213192
logging.info("KeyboardInterrupt")
214-
return (0, output.splitlines())
193+
exit_code = 0
194+
else:
195+
exit_code = channel.recv_exit_status()
215196
finally:
216197
channel.close()
217-
218-
return (channel.recv_exit_status(), output.splitlines())
198+
return (exit_code, output.splitlines())
219199

220200
def run(
221201
self,
222202
command,
223203
display=False,
224-
capture_output=False,
204+
capture=False,
225205
shell=True,
226206
combine_stderr=False,
227207
timeout=None,
228208
stop_event=None,
229209
):
230-
if stop_event:
231-
return self.__run_until_event(
210+
if stop_event is not None:
211+
return self._run_until_event(
232212
command,
233213
stop_event,
234214
display=display,
235215
shell=shell,
236216
combine_stderr=combine_stderr,
237-
capture_output=capture_output,
217+
capture=capture,
238218
)
239219
else:
240-
return self.__run_until_exit(
220+
return self._run_until_exit(
241221
command,
242222
timeout,
243223
display=display,
244224
shell=shell,
245225
combine_stderr=combine_stderr,
246-
capture_output=capture_output,
226+
capture=capture,
247227
)
248228

249229
def disconnect(self):
@@ -254,12 +234,11 @@ def __getattr__(self, target):
254234
def wrapper(*args, **kwargs):
255235
if not self.transport.is_authenticated():
256236
logging.error("SSH session is not ready")
257-
return 1
237+
return
258238

259239
sftp_channel = SFTPController.from_transport(self.transport)
260240
r = getattr(sftp_channel, target)(*args, **kwargs)
261241
sftp_channel.close()
262-
263242
return r
264243

265244
return wrapper

0 commit comments

Comments
 (0)