Package starcluster :: Module threadpool
[hide private]
[frames | no frames]

Source Code for Module starcluster.threadpool

  1  #!/usr/bin/env python 
  2  """ 
  3  ThreadPool module for StarCluster based on WorkerPool 
  4  """ 
  5  import time 
  6  import Queue 
  7  import thread 
  8  import traceback 
  9  import workerpool 
 10   
 11  from starcluster import exception 
 12  from starcluster import progressbar 
 13  from starcluster.logger import log 
class DaemonWorker(workerpool.workers.Worker):
    """
    Improved Worker that sets daemon = True by default and also handles
    communicating exceptions to the parent pool object by adding them to
    the parent pool's exception queue
    """
    def __init__(self, *args, **kwargs):
        super(DaemonWorker, self).__init__(*args, **kwargs)
        self.daemon = True

    def run(self):
        "Get jobs from the queue and perform them as they arrive."
        while 1:
            # Sleep until there is a job to perform.
            job = self.jobs.get()
            try:
                job.run()
            except workerpool.exceptions.TerminationNotice:
                # Stop this worker; the finally clause below still marks
                # the terminating job as done.
                break
            except Exception as e:
                tb_msg = traceback.format_exc()
                # Not every job object is guaranteed to carry a ``jobid``
                # attribute; fall back to the thread id rather than raising
                # AttributeError here and losing the original exception.
                jid = getattr(job, 'jobid', None) or str(thread.get_ident())
                self.jobs.store_exception([e, tb_msg, jid])
            finally:
                self.jobs.task_done()
41
def _worker_factory(parent):
    """Create a DaemonWorker for the given parent object."""
    worker = DaemonWorker(parent)
    return worker
45
class SimpleJob(workerpool.jobs.SimpleJob):
    """
    Job that calls ``method`` with the supplied arguments and, when a
    results queue is given, puts the return value on that queue.
    """
    def __init__(self, method, args=None, kwargs=None, jobid=None,
                 results_queue=None):
        """
        method - callable executed by run()
        args - list/tuple of positional arguments, or a single non-sequence
               value passed as the only positional argument
        kwargs - dict of keyword arguments
        jobid - optional identifier stored on the job
        results_queue - optional queue-like object with a put() method
        """
        self.method = method
        # Fresh containers per instance instead of shared mutable defaults.
        self.args = [] if args is None else args
        self.kwargs = {} if kwargs is None else kwargs
        self.jobid = jobid
        self.results_queue = results_queue

    def run(self):
        """
        Call the method and return its result, or, if a results queue was
        given, put the result on the queue and return whatever put() returns.
        """
        args = self.args
        if args is None:
            args = ()
        elif not isinstance(args, (list, tuple)):
            # A single non-sequence value is passed as one positional arg.
            args = (args,)
        # Anything that is not a dict is ignored as keyword arguments.
        kwargs = self.kwargs if isinstance(self.kwargs, dict) else {}
        r = self.method(*args, **kwargs)
        # Compare against None explicitly so an empty queue-like object that
        # happens to be falsy still receives the result.
        if self.results_queue is not None:
            return self.results_queue.put(r)
        return r
72
class ThreadPool(workerpool.WorkerPool):
    """
    WorkerPool that stores job results and worker exceptions in internal
    queues and shows a progress bar while waiting for unfinished tasks.
    """
    def __init__(self, size=1, maxjobs=0, worker_factory=_worker_factory,
                 disable_threads=False):
        """
        size - pool size passed to WorkerPool (forced to 0 when
               disable_threads is True)
        maxjobs - passed through to WorkerPool
        worker_factory - callable passed through to WorkerPool
        disable_threads - when True, simple_job() runs jobs inline in the
                          calling thread instead of queueing them
        """
        self.disable_threads = disable_threads
        self._exception_queue = Queue.Queue()
        self._results_queue = Queue.Queue()
        self._progress_bar = None
        if self.disable_threads:
            size = 0
        workerpool.WorkerPool.__init__(self, size, maxjobs, worker_factory)

    @staticmethod
    def _drain_queue(queue):
        """
        Remove and return the items currently in ``queue`` without blocking.
        """
        items = []
        # qsize() is only a snapshot, so never block on get(); stop early if
        # the queue turns out to be empty.
        for _ in range(queue.qsize()):
            try:
                items.append(queue.get_nowait())
            except Queue.Empty:
                break
        return items

    @property
    def progress_bar(self):
        """Progress bar for this pool, created on first access."""
        if self._progress_bar is None:
            widgets = ['', progressbar.Fraction(), ' ',
                       progressbar.Bar(marker=progressbar.RotatingMarker()),
                       ' ', progressbar.Percentage(), ' ', ' ']
            pbar = progressbar.ProgressBar(widgets=widgets, maxval=1,
                                           force_update=True)
            self._progress_bar = pbar
        return self._progress_bar

    def simple_job(self, method, args=None, kwargs=None, jobid=None,
                   results_queue=None):
        """
        Wrap ``method`` in a SimpleJob and either queue it on the pool or,
        when threads are disabled, run it immediately in the calling thread.

        Results go to ``results_queue`` when given, otherwise to the pool's
        own results queue.
        """
        if args is None:
            args = []
        if kwargs is None:
            kwargs = {}
        if results_queue is None:
            results_queue = self._results_queue
        job = SimpleJob(method, args, kwargs, jobid,
                        results_queue=results_queue)
        if not self.disable_threads:
            return self.put(job)
        else:
            return job.run()

    def get_results(self):
        """Remove and return the results currently in the results queue."""
        return self._drain_queue(self._results_queue)

    def map(self, fn, *seq):
        """
        Submit one job calling ``fn`` for each tuple of zip(*seq), wait for
        the jobs and return the collected results.

        Results left over in the results queue are discarded first.
        """
        if self._results_queue.qsize() > 0:
            self.get_results()
        # Materialize the zipped arguments so they can be counted.
        job_args_list = list(zip(*seq))
        for job_args in job_args_list:
            self.simple_job(fn, job_args)
        return self.wait(numtasks=len(job_args_list))

    def store_exception(self, e):
        """Record an exception entry reported by a worker."""
        self._exception_queue.put(e)

    def shutdown(self):
        """Shut down the pool and wait for outstanding tasks."""
        log.info("Shutting down threads...")
        workerpool.WorkerPool.shutdown(self)
        self.wait(numtasks=self.size())

    def wait(self, numtasks=None, return_results=True):
        """
        Block until there are no unfinished tasks, updating the progress bar
        once per second.

        numtasks - optional expected number of tasks, used as the progress
                   bar maximum when larger than the current unfinished count
        return_results - when True, return the list of collected results

        Raises exception.ThreadPoolException if any exceptions were stored
        while the tasks ran.
        """
        pbar = self.progress_bar.reset()
        pbar.maxval = self.unfinished_tasks
        if numtasks is not None:
            pbar.maxval = max(numtasks, self.unfinished_tasks)
        while self.unfinished_tasks != 0:
            finished = pbar.maxval - self.unfinished_tasks
            pbar.update(finished)
            log.debug("unfinished_tasks = %d" % self.unfinished_tasks)
            time.sleep(1)
        if pbar.maxval != 0:
            pbar.finish()
        self.join()
        if self._exception_queue.qsize() > 0:
            # Drain the queue so errors from this batch are not reported
            # again by every later call to wait().
            excs = self._drain_queue(self._exception_queue)
            raise exception.ThreadPoolException(
                "An error occured in ThreadPool", excs)
        if return_results:
            return self.get_results()

    def __del__(self):
        log.debug('del called in threadpool')
        self.shutdown()
        self.join()
151
def get_thread_pool(size=10, worker_factory=_worker_factory,
                    disable_threads=False):
    """
    Create a ThreadPool.

    size - pool size passed to ThreadPool
    worker_factory - callable passed to ThreadPool to create workers
    disable_threads - passed to ThreadPool
    """
    # Pass the caller's factory through; the original always passed the
    # module-level _worker_factory and ignored this parameter.
    return ThreadPool(size=size, worker_factory=worker_factory,
                      disable_threads=disable_threads)
157