Package starcluster :: Module ssh
[hide private]
[frames] | [no frames]

Source Code for Module starcluster.ssh

  1  """ 
  2  ssh.py 
  3  Friendly Python SSH2 interface. 
  4  From http://commandline.org.uk/code/ 
  5  License: LGPL 
  6  modified by justin riley (justin.t.riley@gmail.com) 
  7  """ 
  8   
  9  import os 
 10  import re 
 11  import sys 
 12  import stat 
 13  import glob 
 14  import string 
 15  import socket 
 16  import fnmatch 
 17  import paramiko 
 18  import posixpath 
 19   
 20  # windows does not have termios... 
 21  try: 
 22      import termios 
 23      import tty 
 24      HAS_TERMIOS = True 
 25  except ImportError: 
 26      HAS_TERMIOS = False 
 27   
 28  from starcluster import scp 
 29  from starcluster import exception 
 30  from starcluster import progressbar 
 31  from starcluster.logger import log 
class SSHClient(object):
    """
    Establishes an SSH connection to a remote host using either password or
    private key authentication. Once established, this object allows executing
    commands, copying files to/from the remote host, various file querying
    similar to os.path.*, and much more.
    """

    def __init__(self,
                 host,
                 username=None,
                 password=None,
                 private_key=None,
                 private_key_pass=None,
                 port=22,
                 timeout=30):
        """
        host - remote hostname or IP address
        username - remote user (defaults to the local login name)
        password - password to authenticate with (optional if private_key)
        private_key - path to a private key file (RSA or DSA, '~' expanded)
        private_key_pass - passphrase protecting private_key, if any
        port - remote SSH port
        timeout - seconds used for the socket timeout and SSH banner timeout

        Raises exception.SSHNoCredentialsError if neither private_key nor
        password is given.
        """
        self._host = host
        # honor the caller's port (this was previously hard-coded to 22)
        self._port = port
        self._pkey = None
        self._username = username or self._default_username()
        self._password = password
        self._timeout = timeout
        self._sftp = None
        self._scp = None
        self._transport = None
        self._progress_bar = None
        if private_key:
            self._pkey = self.load_private_key(private_key, private_key_pass)
        elif not password:
            raise exception.SSHNoCredentialsError()
        self._glob = SSHGlob(self)
        self.__last_status = None

    @staticmethod
    def _default_username():
        """
        Return the local login name. getpass.getuser() checks LOGNAME first
        (as the old code did) but also USER/LNAME/USERNAME, instead of
        raising KeyError when LOGNAME is unset.
        """
        import getpass
        return getpass.getuser()

    def load_private_key(self, private_key, private_key_pass=None):
        """
        Load the private key file private_key and return a paramiko key
        object, or None if it loads neither as an RSA nor as a DSA key.

        The key type is guessed from the file name ('rsa' or 'dsa' anywhere
        in it). If the guessed type fails to load, the other type is tried.
        """
        log.debug('loading private key %s' % private_key)
        if 'rsa' in private_key:
            loaders = (self._load_rsa_key, self._load_dsa_key)
        elif 'dsa' in private_key:
            loaders = (self._load_dsa_key, self._load_rsa_key)
        else:
            log.debug("specified key does not end in either rsa or dsa" +
                      ", trying both")
            loaders = (self._load_rsa_key, self._load_dsa_key)
        for loader in loaders:
            pkey = loader(private_key, private_key_pass)
            if pkey is not None:
                return pkey
        return None

    def connect(self, host=None, username=None, password=None,
                private_key=None, private_key_pass=None, port=None,
                timeout=None):
        """
        (Re)connect and authenticate, replacing any existing transport.

        Arguments left as None fall back to the values given to the
        constructor. Previously port/timeout defaulted to 22/30 here, which
        silently ignored the constructor's values on reconnects. Returns self.

        Raises exception.SSHConnectionError if the host cannot be reached,
        exception.SSHAuthException if authentication is rejected, and
        exception.SSHError for other SSH failures.
        """
        host = host or self._host
        username = username or self._username
        password = password or self._password
        if port is None:
            port = self._port
        if timeout is None:
            timeout = self._timeout
        pkey = self._pkey
        if private_key:
            pkey = self.load_private_key(private_key, private_key_pass)
        log.debug("connecting to host %s on port %d as user %s" % (host, port,
                                                                  username))
        try:
            sock = self._get_socket(host, port, timeout)
            transport = paramiko.Transport(sock)
            transport.banner_timeout = timeout
        except socket.error:
            raise exception.SSHConnectionError(host, port)
        # Authenticate the transport.
        try:
            transport.connect(username=username, pkey=pkey, password=password)
        except Exception as e:
            # don't leak the half-open transport (and its socket) on failure
            transport.close()
            # AuthenticationException subclasses SSHException: check it first
            if isinstance(e, paramiko.AuthenticationException):
                raise exception.SSHAuthException(username, host)
            if isinstance(e, paramiko.SSHException):
                # str(e) instead of e.args[0]: works for argument-less errors
                raise exception.SSHError(str(e))
            if isinstance(e, (socket.error, EOFError)):
                raise exception.SSHConnectionError(host, port)
            raise exception.SSHError(str(e))
        self.close()
        self._transport = transport
        return self

    @property
    def transport(self):
        """
        This property attempts to return an active SSH transport,
        reconnecting with the stored host/credentials when there is none or
        the current one is no longer active.
        """
        if not self._transport or not self._transport.is_active():
            self.connect(self._host, self._username, self._password,
                         port=self._port, timeout=self._timeout)
        return self._transport

    def get_server_public_key(self):
        """Return the remote server's host key from the (active) transport."""
        return self.transport.get_remote_server_key()

    def is_active(self):
        """Return True if a transport exists and reports itself active."""
        if self._transport:
            return self._transport.is_active()
        return False

    def _get_socket(self, hostname, port, timeout=None):
        """
        Open a TCP socket to hostname:port using the first stream address
        family reported by getaddrinfo. timeout defaults to self._timeout.

        Raises socket.error on resolution/connection failure and
        exception.SSHError if no stream address family is available.
        """
        if timeout is None:
            timeout = self._timeout
        for (family, socktype, proto, canonname, sockaddr) in \
                socket.getaddrinfo(hostname, port, socket.AF_UNSPEC,
                                   socket.SOCK_STREAM):
            if socktype == socket.SOCK_STREAM:
                af = family
                break
        else:
            raise exception.SSHError(
                'No suitable address family for %s' % hostname)
        sock = socket.socket(af, socket.SOCK_STREAM)
        sock.settimeout(timeout)
        try:
            sock.connect((hostname, port))
        except Exception:
            # don't leak the file descriptor when the connection fails
            sock.close()
            raise
        return sock

    def _load_key(self, key_class, key_type, private_key,
                  private_key_pass=None):
        """
        Load private_key ('~' expanded) with paramiko key_class.
        Returns the key, or None (after logging an error) if paramiko
        rejects the file or passphrase.
        """
        private_key_file = os.path.expanduser(private_key)
        try:
            key = key_class.from_private_key_file(private_key_file,
                                                  private_key_pass)
        except paramiko.SSHException:
            log.error('invalid %s key or passphrase specified' % key_type)
            return None
        log.debug("Using private key %s (%s)" % (private_key, key_type))
        return key

    def _load_rsa_key(self, private_key, private_key_pass=None):
        """Load an RSA private key file; returns None on failure."""
        return self._load_key(paramiko.RSAKey, 'rsa', private_key,
                              private_key_pass)

    def _load_dsa_key(self, private_key, private_key_pass=None):
        """Load a DSA private key file; returns None on failure."""
        return self._load_key(paramiko.DSSKey, 'dsa', private_key,
                              private_key_pass)

    @property
    def sftp(self):
        """Establish the SFTP connection (recreated if its socket closed)."""
        if not self._sftp or self._sftp.sock.closed:
            log.debug("creating sftp connection")
            self._sftp = paramiko.SFTPClient.from_transport(self.transport)
        return self._sftp

    @property
    def scp(self):
        """Initialize the SCP client (reporting through the progress bar)."""
        if not self._scp:
            log.debug("creating scp connection")
            self._scp = scp.SCPClient(self.transport,
                                      progress=self._file_transfer_progress)
        return self._scp

    def generate_rsa_key(self):
        """Generate and return a new 2048-bit paramiko RSAKey."""
        return paramiko.RSAKey.generate(2048)

    def get_public_key(self, key):
        """Return key in '<name> <base64>' public key format."""
        return ' '.join([key.get_name(), key.get_base64()])

    def load_remote_rsa_key(self, remote_filename):
        """
        Returns paramiko.RSAKey object for an RSA key located on the remote
        machine. The remote file is closed even if parsing fails.
        """
        rfile = self.remote_file(remote_filename, 'r')
        try:
            return paramiko.RSAKey(file_obj=rfile)
        finally:
            rfile.close()

    def makedirs(self, path, mode=0o755):
        """
        Same as os.makedirs - makes a new directory and automatically creates
        all parent directories if they do not exist.

        mode specifies unix permissions to apply to the new dir
        """
        head, tail = posixpath.split(path)
        if not tail:
            head, tail = posixpath.split(head)
        if head and tail and not self.path_exists(head):
            try:
                self.makedirs(head, mode)
            except (IOError, OSError):
                # be happy if someone already created the path; sftp reports
                # failures as IOError (not OSError), often without an errno
                if not self.isdir(head):
                    raise
        # xxx/newdir/. exists if xxx/newdir exists
        if tail == posixpath.curdir:
            return
        self.mkdir(path, mode)

    def mkdir(self, path, mode=0o755, ignore_failure=False):
        """
        Make a new directory on the remote machine (parents are NOT created;
        see makedirs for that).

        mode specifies unix permissions to apply to the new dir
        ignore_failure - if True, swallow the IOError raised by sftp when the
        directory cannot be created (e.g. it already exists)
        """
        try:
            return self.sftp.mkdir(path, mode)
        except IOError:
            if not ignore_failure:
                raise

    def get_remote_file_lines(self, remote_file, regex=None, matching=True):
        """
        Returns list of lines in a remote_file

        If regex is passed only lines that contain a pattern that matches
        regex will be returned

        If matching is set to False then only lines *not* containing a pattern
        that matches regex will be returned
        """
        f = self.remote_file(remote_file, 'r')
        try:
            flines = f.readlines()
        finally:
            f.close()
        if regex is None:
            return flines
        r = re.compile(regex)
        # keep a line when "does it match" agrees with what was asked for
        return [line for line in flines
                if bool(r.search(line)) == bool(matching)]

    def remove_lines_from_file(self, remote_file, regex):
        """
        Removes lines matching regex from remote_file (rewriting the file).
        Does nothing if regex is None or empty.
        """
        if regex in [None, '']:
            log.debug('no regex supplied...returning')
            return
        lines = self.get_remote_file_lines(remote_file, regex, matching=False)
        log.debug("new %s after removing regex (%s) matches:\n%s" %
                  (remote_file, regex, ''.join(lines)))
        f = self.remote_file(remote_file)
        try:
            f.writelines(lines)
        finally:
            f.close()

    def remote_file(self, file, mode='w'):
        """
        Returns a remote file descriptor (sftp file opened with mode, whose
        name attribute is set to the remote path)
        """
        rfile = self.sftp.open(file, mode)
        rfile.name = file
        return rfile

    def path_exists(self, path):
        """
        Test whether a remote path exists.
        Returns False for broken symbolic links
        """
        try:
            self.stat(path)
            return True
        except IOError:
            return False

    def lpath_exists(self, path):
        """
        Test whether a remote path exists (without following symlinks).
        Returns True for broken symbolic links
        """
        try:
            self.lstat(path)
            return True
        except IOError:
            return False

    def chown(self, uid, gid, remote_file):
        """
        Change the owner (uid) and group (gid) of remote_file.

        Uses SFTPClient.chown(path, uid, gid); the old code passed three
        arguments to SFTPFile.chown(uid, gid), which raised TypeError.
        """
        self.sftp.chown(remote_file, uid, gid)

    def chmod(self, mode, remote_file):
        """
        Apply permissions (mode) to remote_file. Works for directories too,
        since the path is no longer opened as a file first.
        """
        self.sftp.chmod(remote_file, mode)

    def ls(self, path):
        """
        Return a list containing the paths of the entries in the remote path.
        Remote paths are POSIX, so they are joined with posixpath regardless
        of the local platform.
        """
        return [posixpath.join(path, f) for f in self.sftp.listdir(path)]

    def glob(self, pattern):
        """Return a list of remote paths matching the shell-style pattern."""
        return self._glob.glob(pattern)

    def isdir(self, path):
        """
        Return true if the remote path refers to an existing directory.
        """
        try:
            s = self.stat(path)
        except IOError:
            return False
        return stat.S_ISDIR(s.st_mode)

    def isfile(self, path):
        """
        Return true if the remote path refers to an existing file.
        """
        try:
            s = self.stat(path)
        except IOError:
            return False
        return stat.S_ISREG(s.st_mode)

    def stat(self, path):
        """
        Perform a stat system call on the given remote path.
        """
        return self.sftp.stat(path)

    def lstat(self, path):
        """
        Same as stat but doesn't follow symlinks
        """
        return self.sftp.lstat(path)

    @property
    def progress_bar(self):
        """Lazily created progress bar used to report scp transfers."""
        if not self._progress_bar:
            widgets = ['FileTransfer: ', ' ', progressbar.Percentage(), ' ',
                       progressbar.Bar(marker=progressbar.RotatingMarker()),
                       ' ', progressbar.ETA(), ' ',
                       progressbar.FileTransferSpeed()]
            pbar = progressbar.ProgressBar(widgets=widgets,
                                           maxval=1,
                                           force_update=True)
            self._progress_bar = pbar
        return self._progress_bar

    def _file_transfer_progress(self, filename, size, sent):
        """scp progress callback: show filename and sent/size bytes."""
        pbar = self.progress_bar
        pbar.widgets[0] = filename
        pbar.maxval = size
        pbar.update(sent)
        if pbar.finished:
            pbar.reset()

    def _make_list(self, obj):
        """Wrap obj in a list unless it already is a list or tuple."""
        if not isinstance(obj, (list, tuple)):
            return [obj]
        return obj

    def get(self, remotepaths, localpath=''):
        """
        Copies one or more files from the remote host to the local host.

        remotepaths may be a single path or a list/tuple of paths; glob
        patterns are expanded on the remote host. localpath defaults to the
        current working directory. The copy is recursive if any remote path
        is a directory. Raises exception.BaseException if any resulting
        remote path does not exist.
        """
        remotepaths = self._make_list(remotepaths)
        localpath = localpath or os.getcwd()
        globs = []
        noglobs = []
        for rpath in remotepaths:
            if glob.has_magic(rpath):
                globs.append(rpath)
            else:
                noglobs.append(rpath)
        globresults = [self.glob(g) for g in globs]
        remotepaths = noglobs
        for globresult in globresults:
            remotepaths.extend(globresult)
        for rpath in remotepaths:
            if not self.path_exists(rpath):
                raise exception.BaseException(
                    "Remote file or directory does not exist: %s" % rpath)
        recursive = any(self.isdir(rpath) for rpath in remotepaths)
        self.scp.get(remotepaths, localpath, recursive=recursive)

    def put(self, localpaths, remotepath='.'):
        """
        Copies one or more files from the local host to the remote host.
        The copy is recursive if any local path is a directory.
        """
        localpaths = self._make_list(localpaths)
        recursive = any(os.path.isdir(lpath) for lpath in localpaths)
        self.scp.put(localpaths, remote_path=remotepath, recursive=recursive)

    def execute_async(self, command, source_profile=False):
        """
        Executes a remote command so that it continues running even after this
        SSH connection closes. The remote process will be put into the
        background via nohup. Does not return output or check for non-zero exit
        status.
        """
        return self.execute(command, detach=True,
                            source_profile=source_profile)

    def get_last_status(self):
        """Exit status of the last get_status/execute call (None if detached)."""
        return self.__last_status

    def get_status(self, command, source_profile=False):
        """
        Execute a remote command and return the exit status
        """
        channel = self.transport.open_session()
        if source_profile:
            command = "source /etc/profile && %s" % command
        channel.exec_command(command)
        self.__last_status = channel.recv_exit_status()
        return self.__last_status

    def _get_output(self, channel, silent=True, only_printable=False):
        """
        Returns the stdout/stderr output from a paramiko channel as a list of
        stripped strings (non-interactive only). stdout lines come first,
        then stderr lines. If silent is False, the lines are also echoed to
        the local stdout as they are read.
        """
        stdout = channel.makefile('rb', -1)
        stderr = channel.makefile_stderr('rb', -1)
        if silent:
            output = stdout.readlines() + stderr.readlines()
        else:
            output = []
            # readline() returns '' only at EOF
            for line in iter(stdout.readline, ''):
                if only_printable:
                    line = ''.join(c for c in line if c in string.printable)
                if line != '':
                    output.append(line)
                    sys.stdout.write(line)
            for line in stderr.readlines():
                output.append(line)
                # lines already carry their newline; don't double it
                sys.stdout.write(line if line.endswith('\n') else line + '\n')
        if only_printable:
            output = [''.join(c for c in line if c in string.printable)
                      for line in output]
        return [line.strip() for line in output]

    def execute(self, command, silent=True, only_printable=False,
                ignore_exit_status=False, log_output=True, detach=False,
                source_profile=False):
        """
        Execute a remote command and return stdout/stderr

        NOTE: this function blocks until the process finishes

        kwargs:
        silent - do not log output to console
        only_printable - filter the command's output to allow only printable
                         characters
        ignore_exit_status - don't warn about non-zero exit status
        log_output - log output to debug file
        detach - detach the remote process so that it continues to run even
                 after the SSH connection closes (does NOT return output or
                 check for non-zero exit status if detach=True)
        source_profile - if True prefix the command with "source /etc/profile"
        returns List of output lines
        """
        channel = self.transport.open_session()
        if detach:
            command = "nohup %s &" % command
        if source_profile:
            command = "source /etc/profile && %s" % command
        channel.exec_command(command)
        if detach:
            channel.close()
            self.__last_status = None
            return
        output = self._get_output(channel, silent=silent,
                                  only_printable=only_printable)
        exit_status = channel.recv_exit_status()
        self.__last_status = exit_status
        if exit_status != 0:
            msg = "command '%s' failed with status %d" % (command, exit_status)
            if not ignore_exit_status:
                log.error(msg)
            else:
                log.debug(msg)
        if log_output:
            for line in output:
                log.debug(line.strip())
        return output

    def has_required(self, progs):
        """
        Same as check_required but returns False if not all commands exist
        """
        try:
            return self.check_required(progs)
        except exception.RemoteCommandNotFound:
            return False

    def check_required(self, progs):
        """
        Checks that all commands in the progs list exist on the remote system.
        Returns True if all commands exist and raises
        exception.RemoteCommandNotFound if not.
        """
        for prog in progs:
            if not self.which(prog):
                raise exception.RemoteCommandNotFound(prog)
        return True

    def which(self, prog):
        """
        Return the output lines of `which prog` on the remote host, or an
        empty list if `which` exits non-zero (prog not found).
        """
        output = self.execute('which %s' % prog, ignore_exit_status=True)
        if self.get_last_status() != 0:
            # execute() merges stderr into its output, so a "no prog in
            # (...)" message would otherwise look like a successful lookup
            return []
        return output

    def get_path(self):
        """Returns the PATH environment variable on the remote machine"""
        return self.get_env()['PATH']

    def get_env(self):
        """
        Returns the remote machine's environment as a dictionary.
        Output lines without '=' (e.g. continuation lines of multi-line
        values) are skipped instead of raising ValueError.
        """
        env = {}
        for line in self.execute('env'):
            key, sep, val = line.partition('=')
            if sep:
                env[key] = val
        return env

    def close(self):
        """
        Closes the connection and cleans up. The cached sftp/scp clients are
        dropped so they are recreated on the next transport instead of being
        reused while bound to the closed one.
        """
        # getattr: close() may run from __del__ after a failed __init__
        sftp = getattr(self, '_sftp', None)
        if sftp:
            sftp.close()
        transport = getattr(self, '_transport', None)
        if transport:
            transport.close()
        self._sftp = None
        self._scp = None
        self._transport = None

    def _invoke_shell(self, term='screen', cols=80, lines=24):
        """Open a session channel with a pty and start a remote shell on it."""
        chan = self.transport.open_session()
        chan.get_pty(term, cols, lines)
        chan.invoke_shell()
        return chan

    def get_current_user(self):
        """Return the transport's username, or None if not connected."""
        if not self.is_active():
            return
        return self.transport.get_username()

    def switch_user(self, user):
        """
        Reconnect, if necessary, to host as user
        """
        if not self.is_active() or (user and self.get_current_user() != user):
            self.connect(username=user)
        else:
            user = user or self._username
            log.debug("already connected as user %s" % user)

    def interactive_shell(self, user='root'):
        """
        Run an interactive remote shell as user, then switch back to the
        previously connected user. Exceptions from the shell are printed
        with a traceback rather than propagated.
        """
        orig_user = self.get_current_user()
        self.switch_user(user)
        try:
            chan = self._invoke_shell()
            log.info('Starting interactive shell...')
            if HAS_TERMIOS:
                self._posix_shell(chan)
            else:
                self._windows_shell(chan)
            chan.close()
        except Exception as e:
            import traceback
            sys.stdout.write('*** Caught exception: %s: %s\n' %
                             (e.__class__, e))
            traceback.print_exc()
        self.switch_user(orig_user)

    def _posix_shell(self, chan):
        """
        Relay between the local raw-mode terminal and chan until either side
        reaches EOF; the original tty settings are always restored.
        """
        import select

        oldtty = termios.tcgetattr(sys.stdin)
        try:
            tty.setraw(sys.stdin.fileno())
            tty.setcbreak(sys.stdin.fileno())
            chan.settimeout(0.0)

            # needs to be sent to give vim correct size FIX
            chan.send('eval $(resize)\n')

            while True:
                r, w, e = select.select([chan, sys.stdin], [], [])
                if chan in r:
                    try:
                        x = chan.recv(1024)
                        if len(x) == 0:
                            sys.stdout.write('\r\n*** EOF\r\n')
                            break
                        sys.stdout.write(x)
                        sys.stdout.flush()
                    except socket.timeout:
                        pass
                if sys.stdin in r:
                    # fixes up arrow problem
                    x = os.read(sys.stdin.fileno(), 1)
                    if len(x) == 0:
                        break
                    chan.send(x)
        finally:
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty)

    # thanks to Mike Looijmans for this code
    def _windows_shell(self, chan):
        """
        Line-buffered fallback shell for platforms without termios: a thread
        echoes chan output while stdin characters are sent to chan.
        """
        import threading

        sys.stdout.write("Line-buffered terminal emulation. " +
                         "Press F6 or ^Z to send EOF.\r\n\r\n")

        def writeall(sock):
            while True:
                data = sock.recv(256)
                if not data:
                    sys.stdout.write('\r\n*** EOF ***\r\n\r\n')
                    sys.stdout.flush()
                    break
                sys.stdout.write(data)
                sys.stdout.flush()

        writer = threading.Thread(target=writeall, args=(chan,))
        writer.start()

        # needs to be sent to give vim correct size FIX
        chan.send('eval $(resize)\n')

        try:
            while True:
                d = sys.stdin.read(1)
                if not d:
                    break
                chan.send(d)
        except EOFError:
            # user hit ^Z or F6
            pass

    def __del__(self):
        """Attempt to clean up if not explicitly closed."""
        try:
            log.debug('__del__ called')
            self.close()
        except Exception:
            # best effort only: at interpreter shutdown module globals such
            # as log may already be torn down
            pass


# for backwards compatibility
Connection = SSHClient
class SSHGlob(object):
    """
    Remote counterpart of the stdlib glob module: expands shell-style
    wildcards against the file system of the host behind ssh_client, which
    must provide lpath_exists(), isdir() and ls().
    """

    def __init__(self, ssh_client):
        self.ssh = ssh_client

    def glob(self, pathname):
        """Return a list of remote paths matching pathname."""
        return list(self.iglob(pathname))

    def iglob(self, pathname):
        """
        Return an iterator which yields the paths matching a pathname pattern.
        The pattern may contain simple shell-style wildcards a la fnmatch.
        """
        if not glob.has_magic(pathname):
            if self.ssh.lpath_exists(pathname):
                yield pathname
            return
        dirname, basename = posixpath.split(pathname)
        if not dirname:
            for name in self.glob1(posixpath.curdir, basename):
                yield name
            return
        if glob.has_magic(dirname):
            dirs = self.iglob(dirname)
        else:
            dirs = [dirname]
        if glob.has_magic(basename):
            glob_in_dir = self.glob1
        else:
            glob_in_dir = self.glob0
        for dirname in dirs:
            for name in glob_in_dir(dirname, basename):
                yield posixpath.join(dirname, name)

    def glob0(self, dirname, basename):
        """
        Match a literal (wildcard-free) basename inside dirname. Returns
        [basename] if it exists, else [].
        """
        if basename == '':
            # `posixpath.split()` returns an empty basename for paths ending
            # with a directory separator. 'q*x/' should match only
            # directories.
            if self.ssh.isdir(dirname):
                return [basename]
        else:
            # SSHClient provides lpath_exists (there is no lexists method)
            if self.ssh.lpath_exists(posixpath.join(dirname, basename)):
                return [basename]
        return []

    def glob1(self, dirname, pattern):
        """
        Return the names in remote dirname matching the wildcard pattern.
        Hidden names are matched only by patterns starting with '.'. A
        missing or unreadable directory yields [].
        """
        if not dirname:
            dirname = posixpath.curdir
        try:
            text_type = unicode
        except NameError:
            # Python 3: str is already text
            text_type = str
        if isinstance(pattern, text_type) and \
                not isinstance(dirname, text_type):
            dirname = text_type(dirname, 'UTF-8')
        try:
            names = [posixpath.basename(n) for n in self.ssh.ls(dirname)]
        except (IOError, OSError):
            # sftp reports a missing directory as IOError (not OSError on
            # Python 2), so both must be caught here
            return []
        if pattern[0] != '.':
            names = [x for x in names if x[0] != '.']
        # remote names are POSIX and case-sensitive: don't normcase them the
        # way fnmatch.filter does on Windows clients
        return [x for x in names if fnmatch.fnmatchcase(x, pattern)]
def main(argv=None):
    """
    Little test when called directly:

        python ssh.py HOST PRIVATE_KEY [USERNAME]

    Connects to HOST with the given private key and prints the remote
    hostname. The old version built SSHClient without any credentials, which
    always raised SSHNoCredentialsError. Returns 2 (after printing usage to
    stderr) when HOST or PRIVATE_KEY is missing.
    """
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) < 2:
        sys.stderr.write('usage: ssh.py HOST PRIVATE_KEY [USERNAME]\n')
        return 2
    host, private_key = argv[0], argv[1]
    username = argv[2] if len(argv) > 2 else None
    myssh = SSHClient(host, username=username, private_key=private_key)
    try:
        for line in myssh.execute('hostname'):
            sys.stdout.write(line + '\n')
    finally:
        # release the connection even if the command fails
        myssh.close()
    return 0

if __name__ == "__main__":
    sys.exit(main())