Package starcluster :: Module ssh
[hide private]
[frames] | [no frames]

Source Code for Module starcluster.ssh

  1  """ 
  2  ssh.py 
  3  Friendly Python SSH2 interface. 
  4  From http://commandline.org.uk/code/ 
  5  License: LGPL 
  6  modified by justin riley (justin.t.riley@gmail.com) 
  7  """ 
  8   
  9  import os 
 10  import re 
 11  import sys 
 12  import stat 
 13  import glob 
 14  import string 
 15  import socket 
 16  import fnmatch 
 17  import paramiko 
 18  import posixpath 
 19   
 20  # windows does not have termios... 
 21  try: 
 22      import termios 
 23      import tty 
 24      HAS_TERMIOS = True 
 25  except ImportError: 
 26      HAS_TERMIOS = False 
 27   
 28  from starcluster import scp 
 29  from starcluster import exception 
 30  from starcluster import progressbar 
 31  from starcluster.logger import log 
class SSHClient(object):
    """
    Establishes an SSH connection to a remote host using either password or
    private key authentication. Once established, this object allows executing
    commands, copying files to/from the remote host, various file querying
    similar to os.path.*, and much more.

    The connection is made lazily: the ``transport`` property (re)connects
    whenever there is no transport yet or the current one is not active.
    """

    def __init__(self,
                 host,
                 username=None,
                 password=None,
                 private_key=None,
                 private_key_pass=None,
                 port=22,
                 timeout=30):
        """
        host - hostname or IP address of the remote machine
        username - remote login user; defaults to the local user name
        password - password for password authentication
        private_key - path to an RSA/DSA private key file ('~' is expanded)
        private_key_pass - passphrase protecting private_key, if any
        port - TCP port of the remote SSH server
        timeout - socket timeout and SSH banner timeout, in seconds

        Raises exception.SSHNoCredentialsError when neither private_key nor
        password is given. No network connection is made here.
        """
        import getpass
        self._host = host
        # honour the caller's port (this used to be hard-coded to 22)
        self._port = port
        self._pkey = None
        # getpass.getuser() consults LOGNAME first (the previous behaviour)
        # but falls back to USER/LNAME/USERNAME instead of raising KeyError
        # when LOGNAME is unset (cron, Windows, ...)
        self._username = username or getpass.getuser()
        self._password = password
        self._timeout = timeout
        self._sftp = None
        self._scp = None
        self._transport = None
        self._progress_bar = None
        if private_key:
            self._pkey = self.load_private_key(private_key, private_key_pass)
        elif not password:
            raise exception.SSHNoCredentialsError()
        self._glob = SSHGlob(self)

    def load_private_key(self, private_key, private_key_pass=None):
        """
        Load the private key file private_key and return a paramiko key.

        The key type is guessed from the file name: names containing 'rsa'
        are loaded as RSA, names containing 'dsa' as DSA; any other name is
        tried as RSA first and then as DSA. Returns None if paramiko rejects
        the key/passphrase; other errors propagate.
        """
        log.debug('loading private key %s' % private_key)
        if 'rsa' in private_key:
            return self._load_rsa_key(private_key, private_key_pass)
        if 'dsa' in private_key:
            return self._load_dsa_key(private_key, private_key_pass)
        log.debug("specified key does not end in either rsa or dsa, "
                  "trying both")
        pkey = self._load_rsa_key(private_key, private_key_pass)
        if pkey is None:
            pkey = self._load_dsa_key(private_key, private_key_pass)
        return pkey

    def connect(self, host=None, username=None, password=None,
                private_key=None, private_key_pass=None, port=None,
                timeout=None):
        """
        Open and authenticate a new SSH transport, replacing any existing one.

        Every argument defaults to the value given to the constructor
        (port/timeout previously defaulted to 22/30 regardless of it).
        Returns self.

        Raises exception.SSHConnectionError (socket/EOF problems),
        exception.SSHAuthException (rejected credentials) or
        exception.SSHError (any other SSH-level failure).
        """
        host = host or self._host
        username = username or self._username
        password = password or self._password
        if port is None:
            port = self._port
        if timeout is None:
            timeout = self._timeout
        pkey = self._pkey
        if private_key:
            pkey = self.load_private_key(private_key, private_key_pass)
        log.debug("connecting to host %s on port %d as user %s" %
                  (host, port, username))
        try:
            sock = self._get_socket(host, port)
            transport = paramiko.Transport(sock)
            transport.banner_timeout = timeout
        except socket.error:
            raise exception.SSHConnectionError(host, port)
        # Authenticate the transport; close it (and its socket) on failure
        # instead of leaking it.
        authenticated = False
        try:
            transport.connect(username=username, pkey=pkey, password=password)
            authenticated = True
        except paramiko.AuthenticationException:
            raise exception.SSHAuthException(username, host)
        except paramiko.SSHException as e:
            raise exception.SSHError(e.args[0] if e.args else str(e))
        except (socket.error, EOFError):
            raise exception.SSHConnectionError(host, port)
        except Exception as e:
            raise exception.SSHError(str(e))
        finally:
            if not authenticated:
                transport.close()
        self.close()
        self._transport = transport
        return self

    @property
    def transport(self):
        """
        This property attempts to return an active SSH transport,
        reconnecting with the stored settings when needed.
        """
        if not self._transport or not self._transport.is_active():
            self.connect(self._host, self._username, self._password,
                         port=self._port, timeout=self._timeout)
        return self._transport

    def get_server_public_key(self):
        """Return the remote server's host key from the active transport."""
        return self.transport.get_remote_server_key()

    def is_active(self):
        """Return True if a transport exists and reports itself active."""
        if self._transport:
            return self._transport.is_active()
        return False

    def _get_socket(self, hostname, port):
        """
        Return a TCP socket connected to (hostname, port), using the address
        family of the first SOCK_STREAM result from getaddrinfo.
        """
        for (family, socktype, proto, canonname, sockaddr) in \
                socket.getaddrinfo(hostname, port, socket.AF_UNSPEC,
                                   socket.SOCK_STREAM):
            if socktype == socket.SOCK_STREAM:
                af = family
                break
        else:
            raise exception.SSHError(
                'No suitable address family for %s' % hostname)
        sock = socket.socket(af, socket.SOCK_STREAM)
        sock.settimeout(self._timeout)
        connected = False
        try:
            sock.connect((hostname, port))
            connected = True
        finally:
            # don't leak the file descriptor when the connect fails
            if not connected:
                sock.close()
        return sock

    def _load_key(self, key_class, key_type, private_key, private_key_pass):
        """
        Load private_key ('~' expanded) with key_class's
        from_private_key_file. Returns the key, or None after logging an
        error when paramiko raises SSHException (invalid key or passphrase).
        """
        private_key_file = os.path.expanduser(private_key)
        try:
            key = key_class.from_private_key_file(private_key_file,
                                                  private_key_pass)
        except paramiko.SSHException:
            log.error('invalid %s key or passphrase specified' % key_type)
            return None
        log.debug("Using private key %s (%s)" % (private_key, key_type))
        return key

    def _load_rsa_key(self, private_key, private_key_pass=None):
        """Load an RSA private key; returns None if it cannot be parsed."""
        return self._load_key(paramiko.RSAKey, 'rsa', private_key,
                              private_key_pass)

    def _load_dsa_key(self, private_key, private_key_pass=None):
        """Load a DSA private key; returns None if it cannot be parsed."""
        return self._load_key(paramiko.DSSKey, 'dsa', private_key,
                              private_key_pass)

    @property
    def sftp(self):
        """Establish the SFTP connection (rebuilt if its channel closed)."""
        if not self._sftp or self._sftp.sock.closed:
            log.debug("creating sftp connection")
            self._sftp = paramiko.SFTPClient.from_transport(self.transport)
        return self._sftp

    @property
    def scp(self):
        """Initialize the SCP client (reset by close()/reconnect)."""
        if not self._scp:
            log.debug("creating scp connection")
            self._scp = scp.SCPClient(self.transport,
                                      progress=self._file_transfer_progress)
        return self._scp

    def generate_rsa_key(self):
        """Generate and return a new 2048-bit paramiko RSA key."""
        return paramiko.RSAKey.generate(2048)

    def get_public_key(self, key):
        """Return key in authorized_keys format: '<type> <base64>'."""
        return ' '.join([key.get_name(), key.get_base64()])

    def load_remote_rsa_key(self, remote_filename):
        """
        Returns paramiko.RSAKey object for an RSA key located on the remote
        machine
        """
        rfile = self.remote_file(remote_filename, 'r')
        try:
            return paramiko.RSAKey(file_obj=rfile)
        finally:
            rfile.close()

    def makedirs(self, path, mode=0o755):
        """
        Same as os.makedirs - makes a new directory and automatically creates
        all parent directories if they do not exist.

        mode specifies unix permissions to apply to the new dir
        """
        head, tail = posixpath.split(path)
        if not tail:
            head, tail = posixpath.split(head)
        if head and tail and not self.path_exists(head):
            try:
                self.makedirs(head, mode)
            except (IOError, OSError):
                # SFTP failures surface as IOError (usually without an
                # errno), so decide by re-checking: be happy if someone else
                # already created the parent path
                if not self.path_exists(head):
                    raise
        # xxx/newdir/. exists if xxx/newdir exists
        if tail == posixpath.curdir:
            return
        self.mkdir(path, mode)

    def mkdir(self, path, mode=0o755, ignore_failure=False):
        """
        Make a single new directory on the remote machine (parents must
        already exist; see makedirs).

        mode specifies unix permissions to apply to the new dir
        ignore_failure - if True, swallow the IOError raised on failure
        """
        try:
            return self.sftp.mkdir(path, mode)
        except IOError:
            if not ignore_failure:
                raise

    def get_remote_file_lines(self, remote_file, regex=None, matching=True):
        """
        Returns list of lines in a remote_file

        If regex is passed only lines that contain a pattern that matches
        regex will be returned

        If matching is set to False then only lines *not* containing a pattern
        that matches regex will be returned
        """
        f = self.remote_file(remote_file, 'r')
        try:
            flines = f.readlines()
        finally:
            f.close()
        if regex is None:
            return flines
        r = re.compile(regex)
        want_match = bool(matching)
        return [line for line in flines
                if bool(r.search(line)) == want_match]

    def remove_lines_from_file(self, remote_file, regex):
        """
        Removes lines matching regex from remote_file (rewrites the file)
        """
        if regex in [None, '']:
            log.debug('no regex supplied...returning')
            return
        lines = self.get_remote_file_lines(remote_file, regex, matching=False)
        log.debug("new %s after removing regex (%s) matches:\n%s" %
                  (remote_file, regex, ''.join(lines)))
        f = self.remote_file(remote_file)
        try:
            f.writelines(lines)
        finally:
            f.close()

    def remote_file(self, file, mode='w'):
        """
        Returns a remote file descriptor opened via SFTP with the given mode
        """
        rfile = self.sftp.open(file, mode)
        rfile.name = file
        return rfile

    def path_exists(self, path):
        """
        Test whether a remote path exists.
        Returns False for broken symbolic links
        """
        try:
            self.stat(path)
            return True
        except IOError:
            return False

    def lpath_exists(self, path):
        """
        Test whether a remote path exists without following symlinks.
        Returns True for broken symbolic links
        """
        try:
            self.lstat(path)
            return True
        except IOError:
            return False

    def chown(self, uid, gid, remote_file):
        """
        Change the owner (uid) and group (gid) of remote_file.

        Works on directories too (the previous implementation opened the
        path as a file and passed an extra argument to SFTPFile.chown).
        """
        self.sftp.chown(remote_file, uid, gid)

    def chmod(self, mode, remote_file):
        """
        Apply permissions (mode) to remote_file (files or directories)
        """
        self.sftp.chmod(remote_file, mode)

    def ls(self, path):
        """
        Return a list containing the paths of the entries in the remote path.
        """
        # remote paths are POSIX even when this client runs on Windows
        return [posixpath.join(path, f) for f in self.sftp.listdir(path)]

    def glob(self, pattern):
        """Return remote paths matching the shell-style pattern."""
        return self._glob.glob(pattern)

    def isdir(self, path):
        """
        Return true if the remote path refers to an existing directory.
        """
        try:
            s = self.stat(path)
        except IOError:
            return False
        return stat.S_ISDIR(s.st_mode)

    def isfile(self, path):
        """
        Return true if the remote path refers to an existing file.
        """
        try:
            s = self.stat(path)
        except IOError:
            return False
        return stat.S_ISREG(s.st_mode)

    def stat(self, path):
        """
        Perform a stat system call on the given remote path.
        """
        return self.sftp.stat(path)

    def lstat(self, path):
        """
        Same as stat but doesn't follow symlinks
        """
        return self.sftp.lstat(path)

    @property
    def progress_bar(self):
        """Lazily-built progress bar shared by all file transfers."""
        if not self._progress_bar:
            widgets = ['FileTransfer: ', ' ', progressbar.Percentage(), ' ',
                       progressbar.Bar(marker=progressbar.RotatingMarker()),
                       ' ', progressbar.ETA(), ' ',
                       progressbar.FileTransferSpeed()]
            pbar = progressbar.ProgressBar(widgets=widgets,
                                           maxval=1,
                                           force_update=True)
            self._progress_bar = pbar
        return self._progress_bar

    def _file_transfer_progress(self, filename, size, sent):
        """SCP progress callback: update the bar, reset it when finished."""
        pbar = self.progress_bar
        pbar.widgets[0] = filename
        pbar.maxval = size
        pbar.update(sent)
        if pbar.finished:
            pbar.reset()

    def _make_list(self, obj):
        """Wrap obj in a list unless it already is a list or tuple."""
        if not isinstance(obj, (list, tuple)):
            return [obj]
        return obj

    def get(self, remotepaths, localpath=''):
        """
        Copies one or more files from the remote host to the local host.

        remotepaths may be one path or a list/tuple and may contain glob
        wildcards (expanded remotely). localpath defaults to the current
        directory. Raises exception.BaseException if a path does not exist;
        the copy is recursive if any source is a directory.
        """
        remotepaths = self._make_list(remotepaths)
        localpath = localpath or os.getcwd()
        patterns = [rpath for rpath in remotepaths if glob.has_magic(rpath)]
        sources = [rpath for rpath in remotepaths
                   if not glob.has_magic(rpath)]
        globresults = [self.glob(pattern) for pattern in patterns]
        for globresult in globresults:
            sources.extend(globresult)
        for rpath in sources:
            if not self.path_exists(rpath):
                raise exception.BaseException(
                    "Remote file or directory does not exist: %s" % rpath)
        recursive = any(self.isdir(rpath) for rpath in sources)
        self.scp.get(sources, localpath, recursive=recursive)

    def put(self, localpaths, remotepath='.'):
        """
        Copies one or more files from the local host to the remote host.
        The copy is recursive if any local source is a directory.
        """
        localpaths = self._make_list(localpaths)
        recursive = any(os.path.isdir(lpath) for lpath in localpaths)
        self.scp.put(localpaths, remote_path=remotepath, recursive=recursive)

    def execute_async(self, command, source_profile=False):
        """
        Executes a remote command so that it continues running even after this
        SSH connection closes. The remote process will be put into the
        background via nohup. Does not return output or check for non-zero exit
        status.
        """
        return self.execute(command, detach=True,
                            source_profile=source_profile)

    def get_status(self, command, source_profile=False):
        """
        Execute a remote command and return the exit status
        """
        channel = self.transport.open_session()
        if source_profile:
            command = "source /etc/profile && %s" % command
        channel.exec_command(command)
        return channel.recv_exit_status()

    @staticmethod
    def _only_printable(line):
        """Return line with every character not in string.printable removed."""
        return ''.join(c for c in line if c in string.printable)

    def _get_output(self, channel, silent=True, only_printable=False):
        """
        Returns the stdout/stderr output from a paramiko channel as a list of
        stripped strings (stdout lines first, then stderr; non-interactive
        only). Unless silent, lines are echoed to sys.stdout as they arrive.
        """
        stdout = channel.makefile('rb', -1)
        stderr = channel.makefile_stderr('rb', -1)
        if silent:
            output = stdout.readlines() + stderr.readlines()
        else:
            output = []
            while True:
                line = stdout.readline()
                if not line:
                    break
                if only_printable:
                    line = self._only_printable(line)
                if line:
                    output.append(line)
                    sys.stdout.write(line)
            for line in stderr.readlines():
                output.append(line)
                # lines keep their own newline; don't add a second one
                sys.stdout.write(line)
        if only_printable:
            output = [self._only_printable(line) for line in output]
        return [line.strip() for line in output]

    def execute(self, command, silent=True, only_printable=False,
                ignore_exit_status=False, log_output=True, detach=False,
                source_profile=False):
        """
        Execute a remote command and return stdout/stderr

        NOTE: this function blocks until the process finishes

        kwargs:
        silent - do not echo output to stdout while the command runs
        only_printable - filter the command's output to allow only printable
                         characters
        ignore_exit_status - log a non-zero exit status at debug level
                             instead of error
        log_output - log output to debug file
        detach - detach the remote process so that it continues to run even
                 after the SSH connection closes (does NOT return output or
                 check for non-zero exit status if detach=True)
        source_profile - if True prefix the command with "source /etc/profile"
        returns List of output lines (None if detach=True)
        """
        channel = self.transport.open_session()
        if detach:
            command = "nohup %s &" % command
        if source_profile:
            command = "source /etc/profile && %s" % command
        channel.exec_command(command)
        if detach:
            # fire and forget: no output collected, no status checked
            channel.close()
            return
        output = self._get_output(channel, silent=silent,
                                  only_printable=only_printable)
        exit_status = channel.recv_exit_status()
        if exit_status != 0:
            msg = "command '%s' failed with status %d" % (command, exit_status)
            if not ignore_exit_status:
                log.error(msg)
            else:
                log.debug(msg)
        if log_output:
            for line in output:
                log.debug(line.strip())
        return output

    def has_required(self, progs):
        """
        Same as check_required but returns False if not all commands exist
        """
        try:
            return self.check_required(progs)
        except exception.RemoteCommandNotFound:
            return False

    def check_required(self, progs):
        """
        Checks that all commands in the progs list exist on the remote system.
        Returns True if all commands exist and raises
        exception.RemoteCommandNotFound if not.
        """
        for prog in progs:
            # rely on which's exit status: execute() merges stderr into its
            # output, so "which: no X in (...)" would look like a hit
            if self.get_status('which %s' % prog) != 0:
                raise exception.RemoteCommandNotFound(prog)
        return True

    def which(self, prog):
        """Return the output lines of running 'which prog' remotely."""
        return self.execute('which %s' % prog, ignore_exit_status=True)

    def get_path(self):
        """Returns the PATH environment variable on the remote machine"""
        return self.get_env()['PATH']

    def get_env(self):
        """
        Returns the remote machine's environment as a dictionary.
        Output lines without '=' (continuations of multi-line values) are
        skipped instead of raising ValueError.
        """
        env = {}
        for line in self.execute('env'):
            if '=' not in line:
                continue
            key, val = line.split('=', 1)
            env[key] = val
        return env

    def close(self):
        """
        Closes the connection and cleans up. Cached SFTP/SCP clients are
        dropped so the next use builds fresh ones on a new transport.
        Safe to call repeatedly and on partially-initialised objects.
        """
        sftp = getattr(self, '_sftp', None)
        transport = getattr(self, '_transport', None)
        self._sftp = None
        self._scp = None
        self._transport = None
        if sftp:
            sftp.close()
        if transport:
            transport.close()

    def _invoke_shell(self, term='screen', cols=80, lines=24):
        """Open a session channel with a pty and start a remote shell."""
        chan = self.transport.open_session()
        chan.get_pty(term, cols, lines)
        chan.invoke_shell()
        return chan

    def interactive_shell(self, user='root'):
        """
        Start an interactive shell as user (reconnecting if the current
        transport is logged in as someone else). Errors are printed rather
        than raised.
        """
        if user and self.transport.get_username() != user:
            self.connect(username=user)
        try:
            chan = self._invoke_shell()
            log.info('Starting interactive shell...')
            try:
                if HAS_TERMIOS:
                    self._posix_shell(chan)
                else:
                    self._windows_shell(chan)
            finally:
                chan.close()
        except Exception as e:
            import traceback
            print('*** Caught exception: %s: %s' % (e.__class__, e))
            traceback.print_exc()

    def _posix_shell(self, chan):
        """
        Relay a raw-mode local terminal to chan until EOF on either side;
        the terminal settings are restored afterwards.
        """
        import select

        oldtty = termios.tcgetattr(sys.stdin)
        try:
            tty.setraw(sys.stdin.fileno())
            tty.setcbreak(sys.stdin.fileno())
            chan.settimeout(0.0)

            # needs to be sent to give vim correct size FIX
            chan.send('eval $(resize)\n')

            while True:
                r, w, e = select.select([chan, sys.stdin], [], [])
                if chan in r:
                    try:
                        x = chan.recv(1024)
                        if len(x) == 0:
                            sys.stdout.write('\r\n*** EOF\r\n')
                            break
                        sys.stdout.write(x)
                        sys.stdout.flush()
                    except socket.timeout:
                        pass
                if sys.stdin in r:
                    # fixes up arrow problem
                    x = os.read(sys.stdin.fileno(), 1)
                    if len(x) == 0:
                        break
                    chan.send(x)
        finally:
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, oldtty)

    # thanks to Mike Looijmans for this code
    def _windows_shell(self, chan):
        """
        Line-buffered shell for platforms without termios: a background
        thread prints channel output while stdin is forwarded char by char.
        """
        import threading

        sys.stdout.write("Line-buffered terminal emulation. "
                         "Press F6 or ^Z to send EOF.\r\n\r\n")

        def writeall(sock):
            while True:
                data = sock.recv(256)
                if not data:
                    sys.stdout.write('\r\n*** EOF ***\r\n\r\n')
                    sys.stdout.flush()
                    break
                sys.stdout.write(data)
                sys.stdout.flush()

        writer = threading.Thread(target=writeall, args=(chan,))
        writer.start()

        # needs to be sent to give vim correct size FIX
        chan.send('eval $(resize)\n')

        try:
            while True:
                d = sys.stdin.read(1)
                if not d:
                    break
                chan.send(d)
        except EOFError:
            # user hit ^Z or F6
            pass

    def __del__(self):
        """Attempt to clean up if not explicitly closed."""
        try:
            log.debug('__del__ called')
        except Exception:
            pass  # module globals may already be torn down at exit
        try:
            self.close()
        except Exception:
            pass  # exceptions cannot usefully escape a finalizer


# for backwards compatibility
Connection = SSHClient
class SSHGlob(object):
    """
    Remote counterpart of the stdlib glob module: expands shell-style
    wildcard patterns against the filesystem of an SSH client.
    """

    def __init__(self, ssh_client):
        # ssh_client must provide lpath_exists(), isdir() and ls()
        self.ssh = ssh_client

    def glob(self, pathname):
        """Return a list of remote paths matching pathname."""
        return list(self.iglob(pathname))

    def iglob(self, pathname):
        """
        Return an iterator which yields the paths matching a pathname pattern.
        The pattern may contain simple shell-style wildcards a la fnmatch.
        """
        if not glob.has_magic(pathname):
            if self.ssh.lpath_exists(pathname):
                yield pathname
            return
        dirname, basename = posixpath.split(pathname)
        if not dirname:
            for name in self.glob1(posixpath.curdir, basename):
                yield name
            return
        if glob.has_magic(dirname):
            dirs = self.iglob(dirname)
        else:
            dirs = [dirname]
        if glob.has_magic(basename):
            glob_in_dir = self.glob1
        else:
            glob_in_dir = self.glob0
        for dirname in dirs:
            for name in glob_in_dir(dirname, basename):
                yield posixpath.join(dirname, name)

    def glob0(self, dirname, basename):
        """
        Match a literal (wildcard-free) basename inside dirname.
        Returns [basename] if it exists, otherwise [].
        """
        if basename == '':
            # `posixpath.split()` returns an empty basename for paths ending
            # with a directory separator. 'q*x/' should match only
            # directories.
            if self.ssh.isdir(dirname):
                return [basename]
        else:
            # the client's lexists-style check is named lpath_exists
            # (the previous call to a nonexistent `lexists` raised
            # AttributeError)
            if self.ssh.lpath_exists(posixpath.join(dirname, basename)):
                return [basename]
        return []

    def glob1(self, dirname, pattern):
        """
        Return the entry names in remote dirname matching the wildcard
        pattern. Dot-files only match patterns that start with '.'.
        An unlistable directory yields [].
        """
        if not dirname:
            dirname = posixpath.curdir
        try:
            text_type = unicode  # Python 2
        except NameError:
            text_type = str  # Python 3: every str is already text
        if isinstance(pattern, text_type) and \
                not isinstance(dirname, text_type):
            dirname = text_type(dirname, 'UTF-8')
        try:
            names = [posixpath.basename(n) for n in self.ssh.ls(dirname)]
        except (IOError, OSError):
            # SFTP listdir failures surface as IOError, which is not an
            # OSError subclass on Python 2
            return []
        if not pattern.startswith('.'):
            names = [name for name in names if not name.startswith('.')]
        return fnmatch.filter(names, pattern)
def main():
    """
    Little smoke test when the module is run directly: print the output of
    `hostname` executed on a remote machine.

    NOTE: SSHClient raises exception.SSHNoCredentialsError unless a password
    or private_key is supplied, so fill in real connection details first.
    """
    # Set these to your own details.
    client = SSHClient('somehost.domain.com')
    hostname_output = client.execute('hostname')
    print(hostname_output)
    client.close()


if __name__ == "__main__":
    main()