1 """
2 ssh.py
3 Friendly Python SSH2 interface.
4 From http://commandline.org.uk/code/
5 License: LGPL
6 modified by justin riley (justin.t.riley@gmail.com)
7 """
8
9 import os
10 import re
11 import sys
12 import stat
13 import glob
14 import string
15 import socket
16 import fnmatch
17 import hashlib
18 import posixpath
19
20 import paramiko
21 from paramiko import util
22 from paramiko import RSAKey
23 from pyasn1.codec.der import encoder
24 from pyasn1.type import univ
25
26
27 try:
28 import termios
29 import tty
30 HAS_TERMIOS = True
31 except ImportError:
32 HAS_TERMIOS = False
33
34 from starcluster import scp
35 from starcluster import exception
36 from starcluster import progressbar
37 from starcluster.logger import log
41 """
42 Establishes an SSH connection to a remote host using either password or
43 private key authentication. Once established, this object allows executing
44 commands, copying files to/from the remote host, various file querying
45 similar to os.path.*, and much more.
46 """
47
def __init__(self,
             host,
             username=None,
             password=None,
             private_key=None,
             private_key_pass=None,
             port=22,
             timeout=30):
    """
    Store connection settings for ``host``.

    This constructor does not open the connection; it only records the
    settings and loads the private key (if any).

    host - hostname or address of the remote machine
    username - remote user name; defaults to the local login name as
               reported by getpass.getuser() (which consults $LOGNAME
               first, then other variables)
    password - password used for authentication
    private_key - path to a private key file used for authentication
    private_key_pass - passphrase for private_key, if any
    port - remote SSH port (default 22)
    timeout - socket timeout in seconds (default 30)

    Raises exception.SSHNoCredentialsError if neither private_key nor
    password is supplied.
    """
    import getpass
    self._host = host
    # honour the caller's port (previously always forced to 22)
    self._port = port
    self._pkey = None
    # getpass.getuser() checks LOGNAME first, so behaviour is unchanged when
    # it is set, but it no longer raises KeyError when LOGNAME is missing
    self._username = username or getpass.getuser()
    self._password = password
    self._timeout = timeout
    self._sftp = None
    self._scp = None
    self._transport = None
    self._progress_bar = None
    if private_key:
        self._pkey = self.load_private_key(private_key, private_key_pass)
    elif not password:
        raise exception.SSHNoCredentialsError()
    self._glob = SSHGlob(self)
    self.__last_status = None
72
74
75 log.debug('loading private key %s' % private_key)
76 if private_key.endswith('rsa') or private_key.count('rsa'):
77 pkey = self._load_rsa_key(private_key, private_key_pass)
78 elif private_key.endswith('dsa') or private_key.count('dsa'):
79 pkey = self._load_dsa_key(private_key, private_key_pass)
80 else:
81 log.debug(
82 "specified key does not end in either rsa or dsa, trying both")
83 pkey = self._load_rsa_key(private_key, private_key_pass)
84 if pkey is None:
85 pkey = self._load_dsa_key(private_key, private_key_pass)
86 return pkey
87
def connect(self, host=None, username=None, password=None,
            private_key=None, private_key_pass=None, port=None, timeout=None):
    """
    Open a new SSH transport to the remote host and make it the active one.

    Any argument left as None falls back to the value given to the
    constructor (including port and timeout). If private_key is given it
    is loaded and used instead of the constructor's key.

    The previously active transport/SFTP session is closed only after the
    new transport has connected successfully. Returns self.

    Raises:
    exception.SSHConnectionError - the socket could not be opened, or the
        connection dropped during the handshake
    exception.SSHAuthException - authentication was rejected
    exception.SSHError - any other SSH/transport failure
    """
    host = host or self._host
    username = username or self._username
    password = password or self._password
    # fall back to the instance settings rather than hard-coded 22/30
    if port is None:
        port = self._port
    if timeout is None:
        timeout = self._timeout
    pkey = self._pkey
    if private_key:
        pkey = self.load_private_key(private_key, private_key_pass)
    log.debug("connecting to host %s on port %d as user %s" % (host, port,
                                                                 username))
    try:
        sock = self._get_socket(host, port)
        transport = paramiko.Transport(sock)
        transport.banner_timeout = timeout
    except socket.error:
        raise exception.SSHConnectionError(host, port)

    try:
        transport.connect(username=username, pkey=pkey, password=password)
    except Exception as e:
        # Shut the half-open transport down so its socket (and the thread
        # paramiko runs it in) is not leaked, then translate the error.
        transport.close()
        # AuthenticationException subclasses SSHException: check it first
        if isinstance(e, paramiko.AuthenticationException):
            raise exception.SSHAuthException(username, host)
        if isinstance(e, paramiko.SSHException):
            raise exception.SSHError(e.args[0] if e.args else str(e))
        if isinstance(e, (socket.error, EOFError)):
            raise exception.SSHConnectionError(host, port)
        raise exception.SSHError(str(e))
    self.close()
    self._transport = transport
    return self
121
122 @property
124 """
125 This property attempts to return an active SSH transport
126 """
127 if not self._transport or not self._transport.is_active():
128 self.connect(self._host, self._username, self._password,
129 port=self._port, timeout=self._timeout)
130 return self._transport
131
133 return self.transport.get_remote_server_key()
134
136 if self._transport:
137 return self._transport.is_active()
138 return False
139
141 addrinfo = socket.getaddrinfo(hostname, port, socket.AF_UNSPEC,
142 socket.SOCK_STREAM)
143 for (family, socktype, proto, canonname, sockaddr) in addrinfo:
144 if socktype == socket.SOCK_STREAM:
145 af = family
146 break
147 else:
148 raise exception.SSHError(
149 'No suitable address family for %s' % hostname)
150 sock = socket.socket(af, socket.SOCK_STREAM)
151 sock.settimeout(self._timeout)
152 sock.connect((hostname, port))
153 return sock
154
156 private_key_file = os.path.expanduser(private_key)
157 try:
158 rsa_key = paramiko.RSAKey.from_private_key_file(private_key_file,
159 private_key_pass)
160 log.debug("Using private key %s (rsa)" % private_key)
161 return rsa_key
162 except paramiko.SSHException:
163 log.error('invalid rsa key or passphrase specified')
164
166 private_key_file = os.path.expanduser(private_key)
167 try:
168 dsa_key = paramiko.DSSKey.from_private_key_file(private_key_file,
169 private_key_pass)
170 log.info("Using private key %s (dsa)" % private_key)
171 return dsa_key
172 except paramiko.SSHException:
173 log.error('invalid dsa key or passphrase specified')
174
175 @property
177 """Establish the SFTP connection."""
178 if not self._sftp or self._sftp.sock.closed:
179 log.debug("creating sftp connection")
180 self._sftp = paramiko.SFTPClient.from_transport(self.transport)
181 return self._sftp
182
183 @property
191
193 return paramiko.RSAKey.generate(2048)
194
196 return ' '.join([key.get_name(), key.get_base64()])
197
199 """
200 Returns paramiko.RSAKey object for an RSA key located on the remote
201 machine
202 """
203 rfile = self.remote_file(remote_filename, 'r')
204 key = paramiko.RSAKey(file_obj=rfile)
205 rfile.close()
206 return key
207
209 """
210 Same as os.makedirs - makes a new directory and automatically creates
211 all parent directories if they do not exist.
212
213 mode specifies unix permissions to apply to the new dir
214 """
215 head, tail = posixpath.split(path)
216 if not tail:
217 head, tail = posixpath.split(head)
218 if head and tail and not self.path_exists(head):
219 try:
220 self.makedirs(head, mode)
221 except OSError, e:
222
223 if e.errno != os.errno.EEXIST:
224 raise
225
226 if tail == posixpath.curdir:
227 return
228 self.mkdir(path, mode)
229
def mkdir(self, path, mode=0o755, ignore_failure=False):
    """
    Make a new directory on the remote machine.

    This does not create missing parent directories; use makedirs for
    that.

    path - remote directory to create
    mode - unix permissions to apply to the new dir (default 0755)
    ignore_failure - if True, an IOError raised by the SFTP mkdir call is
                     swallowed instead of propagated

    Returns the result of the SFTP mkdir call, or None when a failure was
    ignored.
    """
    try:
        return self.sftp.mkdir(path, mode)
    except IOError:
        if not ignore_failure:
            raise
243
245 """
246 Returns list of lines in a remote_file
247
248 If regex is passed only lines that contain a pattern that matches
249 regex will be returned
250
251 If matching is set to False then only lines *not* containing a pattern
252 that matches regex will be returned
253 """
254 f = self.remote_file(remote_file, 'r')
255 flines = f.readlines()
256 f.close()
257 if regex is None:
258 return flines
259 r = re.compile(regex)
260 lines = []
261 for line in flines:
262 match = r.search(line)
263 if matching and match:
264 lines.append(line)
265 elif not matching and not match:
266 lines.append(line)
267 return lines
268
270 """
271 Removes lines matching regex from remote_file
272 """
273 if regex in [None, '']:
274 log.debug('no regex supplied...returning')
275 return
276 lines = self.get_remote_file_lines(remote_file, regex, matching=False)
277 log.debug("new %s after removing regex (%s) matches:\n%s" %
278 (remote_file, regex, ''.join(lines)))
279 f = self.remote_file(remote_file)
280 f.writelines(lines)
281 f.close()
282
283 - def unlink(self, remote_file):
285
287 """
288 Returns a remote file descriptor
289 """
290 rfile = self.sftp.open(file, mode)
291 rfile.name = file
292 return rfile
293
295 """
296 Test whether a remote path exists.
297 Returns False for broken symbolic links
298 """
299 try:
300 self.stat(path)
301 return True
302 except IOError:
303 return False
304
306 """
307 Test whether a remote path exists.
308 Returns True for broken symbolic links
309 """
310 try:
311 self.lstat(path)
312 return True
313 except IOError:
314 return False
315
316 - def chown(self, uid, gid, remote_file):
323
324 - def chmod(self, mode, remote_file):
331
def ls(self, path):
    """
    Return a list containing the names of the entries in the remote path.

    Each entry is joined onto ``path`` with POSIX separators. The path
    lives on the remote host, so the local os.path flavour (which would use
    backslashes on Windows) must not be used.
    """
    return [posixpath.join(path, name) for name in self.sftp.listdir(path)]
337
def glob(self, pattern):
    """Expand the shell-style ``pattern`` against the remote filesystem."""
    globber = self._glob
    return globber.glob(pattern)
340
342 """
343 Return true if the remote path refers to an existing directory.
344 """
345 try:
346 s = self.stat(path)
347 except IOError:
348 return False
349 return stat.S_ISDIR(s.st_mode)
350
352 """
353 Return true if the remote path refers to an existing file.
354 """
355 try:
356 s = self.stat(path)
357 except IOError:
358 return False
359 return stat.S_ISREG(s.st_mode)
360
def stat(self, path):
    """
    Return the SFTP stat() result for ``path`` on the remote host.

    Errors from the SFTP call propagate unchanged to the caller.
    """
    client = self.sftp
    return client.stat(path)
366
368 """
369 Same as stat but doesn't follow symlinks
370 """
371 return self.sftp.lstat(path)
372
373 @property
375 if not self._progress_bar:
376 widgets = ['FileTransfer: ', ' ', progressbar.Percentage(), ' ',
377 progressbar.Bar(marker=progressbar.RotatingMarker()),
378 ' ', progressbar.ETA(), ' ',
379 progressbar.FileTransferSpeed()]
380 pbar = progressbar.ProgressBar(widgets=widgets,
381 maxval=1,
382 force_update=True)
383 self._progress_bar = pbar
384 return self._progress_bar
385
387 pbar = self.progress_bar
388 pbar.widgets[0] = filename
389 pbar.maxval = size
390 pbar.update(sent)
391 if pbar.finished:
392 pbar.reset()
393
395 if not isinstance(obj, (list, tuple)):
396 return [obj]
397 return obj
398
def get(self, remotepaths, localpath=''):
    """
    Copies one or more files from the remote host to the local host.

    remotepaths - one remote path or a list/tuple of them; entries that
                  contain glob wildcards are expanded remotely and placed
                  after the literal entries
    localpath - local destination (defaults to the current directory)

    The copy is recursive if any of the resolved remote paths is a
    directory. Raises exception.BaseException if a resolved remote path
    does not exist.
    """
    requested = self._make_list(remotepaths)
    destination = localpath or os.getcwd()
    patterns = [p for p in requested if glob.has_magic(p)]
    resolved = [p for p in requested if not glob.has_magic(p)]
    for pattern in patterns:
        resolved.extend(self.glob(pattern))
    for candidate in resolved:
        if not self.path_exists(candidate):
            raise exception.BaseException(
                "Remote file or directory does not exist: %s" % candidate)
    recursive = any(self.isdir(candidate) for candidate in resolved)
    self.scp.get(resolved, destination, recursive=recursive)
426
def put(self, localpaths, remotepath='.'):
    """
    Copies one or more files from the local host to the remote host.

    localpaths - one local path or a list/tuple of them
    remotepath - remote destination (defaults to '.')

    The copy is recursive if any of the local paths is a directory.
    """
    sources = self._make_list(localpaths)
    recursive = any(os.path.isdir(source) for source in sources)
    self.scp.put(sources, remote_path=remotepath, recursive=recursive)
438
440 """
441 Executes a remote command so that it continues running even after this
442 SSH connection closes. The remote process will be put into the
443 background via nohup. Does not return output or check for non-zero exit
444 status.
445 """
446 return self.execute(command, detach=True,
447 source_profile=source_profile)
448
450 return self.__last_status
451
def get_status(self, command, source_profile=False):
    """
    Run ``command`` remotely and return its exit status.

    No output is collected. If source_profile is True the command is
    prefixed with "source /etc/profile && ". The status is also stored
    in self.__last_status.
    """
    session = self.transport.open_session()
    to_run = command
    if source_profile:
        to_run = "source /etc/profile && %s" % command
    session.exec_command(to_run)
    status = session.recv_exit_status()
    self.__last_status = status
    return status
462
def _get_output(self, channel, silent=True, only_printable=False):
    """
    Returns the stdout/stderr output from a paramiko channel as a list of
    strings (non-interactive only).

    channel - channel on which a command has been started
    silent - if False, echo every line to the local stdout as it is read
    only_printable - drop characters that are not in string.printable

    Returns stdout lines followed by stderr lines, each stripped of
    leading/trailing whitespace.
    """
    def _printable(text):
        # keep only characters from string.printable
        return ''.join(c for c in text if c in string.printable)

    stdout = channel.makefile('rb', -1)
    stderr = channel.makefile_stderr('rb', -1)
    if silent:
        output = stdout.readlines() + stderr.readlines()
    else:
        output = []
        # read stdout line by line so the echo follows the command's progress
        while True:
            line = stdout.readline()
            if not line:
                break  # EOF
            if only_printable:
                line = _printable(line)
                if not line:
                    continue
            output.append(line)
            # write() rather than print: the line already carries its own
            # newline, so print would add a second one
            sys.stdout.write(line)
        for line in stderr.readlines():
            output.append(line)
            sys.stdout.write(line)
    if only_printable:
        output = [_printable(line) for line in output]
    # build a real list (not a lazy map) so callers can test it for emptiness
    return [line.strip() for line in output]
491
def execute(self, command, silent=True, only_printable=False,
            ignore_exit_status=False, log_output=True, detach=False,
            source_profile=False):
    """
    Run ``command`` on the remote host and return its stdout/stderr.

    Unless detach is set, this call blocks until the remote process exits.

    kwargs:
    silent - do not echo the command's output to the local console
    only_printable - filter the output down to printable characters
    ignore_exit_status - report a non-zero exit status with log.debug
                         instead of log.error
    log_output - send every output line to log.debug
    detach - wrap the command in "nohup ... &" so it keeps running after
             the SSH connection closes; nothing is read back, the channel
             is closed immediately, the stored status is reset to None
             and None is returned
    source_profile - prefix the command with "source /etc/profile && "

    Returns the list of output lines (None when detach=True).
    """
    session = self.transport.open_session()
    if detach:
        command = "nohup %s &" % command
    if source_profile:
        command = "source /etc/profile && %s" % command
    session.exec_command(command)
    if detach:
        session.close()
        self.__last_status = None
        return None
    lines = self._get_output(session, silent=silent,
                             only_printable=only_printable)
    status = session.recv_exit_status()
    self.__last_status = status
    if status != 0:
        msg = "command '%s' failed with status %d" % (command, status)
        report = log.debug if ignore_exit_status else log.error
        report(msg)
    if log_output:
        for line in lines:
            log.debug(line.strip())
    return lines
538
547
549 """
550 Checks that all commands in the progs list exist on the remote system.
551 Returns True if all commands exist and raises exception.CommandNotFound
552 if not.
553 """
554 for prog in progs:
555 if not self.which(prog):
556 raise exception.RemoteCommandNotFound(prog)
557 return True
558
560 return self.execute('which %s' % prog, ignore_exit_status=True)
561
563 """Returns the PATH environment variable on the remote machine"""
564 return self.get_env()['PATH']
565
567 """Returns the remote machine's environment as a dictionary"""
568 env = {}
569 for line in self.execute('env'):
570 key, val = line.split('=', 1)
571 env[key] = val
572 return env
573
575 """Closes the connection and cleans up."""
576 if self._sftp:
577 self._sftp.close()
578 if self._transport:
579 self._transport.close()
580
582 chan = self.transport.open_session()
583 chan.get_pty(term, cols, lines)
584 chan.invoke_shell()
585 return chan
586
591
593 """
594 Reconnect, if necessary, to host as user
595 """
596 if not self.is_active() or user and self.get_current_user() != user:
597 self.connect(username=user)
598 else:
599 user = user or self._username
600 log.debug("already connected as user %s" % user)
601
618
620 import select
621
622 oldtty = termios.tcgetattr(sys.stdin)
623 try:
624 tty.setraw(sys.stdin.fileno())
625 tty.setcbreak(sys.stdin.fileno())
626 chan.settimeout(0.0)
627
628
629 chan.send('eval $(resize)\n')
630
631 while True:
632 r, w, e = select.select([chan, sys.stdin], [], [])
633 if chan in r:
634 try:
635 x = chan.recv(1024)
636 if len(x) == 0:
637 print '\r\n*** EOF\r\n',
638 break
639 sys.stdout.write(x)
640 sys.stdout.flush()
641 except socket.timeout:
642 pass
643 if sys.stdin in r:
644
645 x = os.read(sys.stdin.fileno(), 1)
646 if len(x) == 0:
647 break
648 chan.send(x)
649 finally:
650 termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty)
651
652
654 import threading
655
656 sys.stdout.write("Line-buffered terminal emulation. "
657 "Press F6 or ^Z to send EOF.\r\n\r\n")
658
659 def writeall(sock):
660 while True:
661 data = sock.recv(256)
662 if not data:
663 sys.stdout.write('\r\n*** EOF ***\r\n\r\n')
664 sys.stdout.flush()
665 break
666 sys.stdout.write(data)
667 sys.stdout.flush()
668
669 writer = threading.Thread(target=writeall, args=(chan,))
670 writer.start()
671
672
673 chan.send('eval $(resize)\n')
674
675 try:
676 while True:
677 d = sys.stdin.read(1)
678 if not d:
679 break
680 chan.send(d)
681 except EOFError:
682
683 pass
684
686 """Attempt to clean up if not explicitly closed."""
687 log.debug('__del__ called')
688 self.close()
689
690
691
# Alternate public name: ``Connection`` and ``SSHClient`` are the same class.
Connection = SSHClient
696
698 self.ssh = ssh_client
699
def glob(self, pathname):
    """Return, as a list, every remote path that iglob yields for ``pathname``."""
    matches = self.iglob(pathname)
    return list(matches)
702
def iglob(self, pathname):
    """
    Return an iterator which yields the remote paths matching ``pathname``.

    The pattern may contain simple shell-style wildcards a la fnmatch.
    Every filesystem query goes through self.ssh; matching within a single
    directory is delegated to glob1 (wildcard basename) or glob0 (literal
    basename).
    """
    if not glob.has_magic(pathname):
        # literal path: yield it only if it exists on the remote side
        if self.ssh.lpath_exists(pathname):
            yield pathname
        return
    head, tail = posixpath.split(pathname)
    if not head:
        # no directory component: match relative to the current directory
        for match in self.glob1(posixpath.curdir, tail):
            yield match
        return
    # a wildcard in the directory part is expanded recursively first
    parents = self.iglob(head) if glob.has_magic(head) else [head]
    matcher = self.glob1 if glob.has_magic(tail) else self.glob0
    for parent in parents:
        for match in matcher(parent, tail):
            yield posixpath.join(parent, match)
728
def glob0(self, dirname, basename):
    """
    Match a literal (wildcard-free) ``basename`` inside remote ``dirname``.

    Returns [basename] if it names an existing remote entry, else [].
    An empty basename (the pattern ended with a '/') matches only when
    dirname is a directory.
    """
    if basename == '':
        found = self.ssh.isdir(dirname)
    else:
        # NOTE(review): iglob checks existence via self.ssh.lpath_exists,
        # but this calls self.ssh.lexists -- confirm the client provides it.
        found = self.ssh.lexists(posixpath.join(dirname, basename))
    return [basename] if found else []
739
def glob1(self, dirname, pattern):
    """
    Return the names of entries in remote ``dirname`` matching ``pattern``.

    pattern is an fnmatch-style wildcard applied to bare entry names.
    Entries whose names start with '.' are only considered when the pattern
    itself starts with '.'. A directory that cannot be listed yields [].
    """
    if not dirname:
        dirname = posixpath.curdir
    try:
        text_type = unicode  # Python 2
    except NameError:
        text_type = str  # Python 3
    if isinstance(pattern, text_type) and not isinstance(dirname, text_type):
        # keep dirname and pattern the same string type
        dirname = text_type(dirname, 'UTF-8')
    try:
        entries = self.ssh.ls(dirname)
    except (IOError, OSError):
        # The SFTP-backed helpers of this client report failures as IOError
        # (cf. path_exists), which is not an os.error subclass on Python 2.
        return []
    # remote paths are POSIX, so use posixpath rather than os.path
    names = [posixpath.basename(entry) for entry in entries]
    if pattern[0] != '.':
        names = [name for name in names if name[0] != '.']
    return fnmatch.filter(names, pattern)
754
755
# Object identifier paired with an ASN.1 NULL to form the algorithm
# identifier of the DER-encoded RSA key (1.2.840.113549.1.1.1 is the
# well-known rsaEncryption OID).
RSA_OID = univ.ObjectIdentifier('1.2.840.113549.1.1.1')
# Order in which the RSA private-key integers follow the leading version
# number (0) in the encoded key sequence.
RSA_PARAMS = 'n e d p q dp dq invq'.split()
761 return char.join(
762 string[i:i + every] for i in xrange(0, len(string), every))
763
766 seq = univ.Sequence()
767 for i in range(len(vals)):
768 seq.setComponentByPosition(i, vals[i])
769 return seq
770
773 oid = ASN1Sequence(RSA_OID, univ.Null())
774 key = univ.Sequence().setComponentByPosition(0, univ.Integer(0))
775 for i in range(len(RSA_PARAMS)):
776 key.setComponentByPosition(i + 1, univ.Integer(params[RSA_PARAMS[i]]))
777 octkey = encoder.encode(key)
778 seq = ASN1Sequence(univ.Integer(0), oid, univ.OctetString(octkey))
779 return encoder.encode(seq)
780
783 """
784 Returns the fingerprint of a private RSA key as a 59-character string (40
785 characters separated every 2 characters by a ':'). The fingerprint is
786 computed using a SHA1 digest of the DER encoded RSA private key.
787 """
788 try:
789 k = RSAKey.from_private_key_file(key_location)
790 except paramiko.SSHException:
791 raise exception.SSHError("Invalid RSA private key file: %s" %
792 key_location)
793 params = dict(invq=util.mod_inverse(k.q, k.p), dp=k.d % (k.p - 1),
794 dq=k.d % (k.q - 1), d=k.d, n=k.n, p=k.p, q=k.q, e=k.e)
795 assert len(params) == 8
796
797 pkcs8der = export_rsa_to_pkcs8(params)
798 sha1digest = hashlib.sha1(pkcs8der).hexdigest()
799 return insert_char_every_n_chars(sha1digest, ':', 2)
800
803 try:
804 k = RSAKey.from_private_key_file(pubkey_location)
805 except paramiko.SSHException:
806 raise exception.SSHError("Invalid RSA private key file: %s" %
807 pubkey_location)
808 md5digest = hashlib.md5(str(k)).hexdigest()
809 return insert_char_every_n_chars(md5digest, ':', 2)
810
827
842