diff --git a/compose/parallel.py b/compose/parallel.py index e360ca357..63417dcb0 100644 --- a/compose/parallel.py +++ b/compose/parallel.py @@ -17,6 +17,8 @@ from compose.utils import get_output_stream log = logging.getLogger(__name__) +STOP = object() + def parallel_execute(objects, func, get_name, msg, get_deps=None): """Runs func on objects in parallel while ensuring that func is @@ -32,7 +34,7 @@ def parallel_execute(objects, func, get_name, msg, get_deps=None): for obj in objects: writer.initialize(get_name(obj)) - events = parallel_execute_stream(objects, func, get_deps) + events = parallel_execute_iter(objects, func, get_deps) errors = {} results = [] @@ -65,12 +67,19 @@ def _no_deps(x): class State(object): + """ + Holds the state of a partially-complete parallel operation. + + state.started: objects being processed + state.finished: objects which have been processed + state.failed: objects which either failed or whose dependencies failed + """ def __init__(self, objects): self.objects = objects - self.started = set() # objects being processed - self.finished = set() # objects which have been processed - self.failed = set() # objects which either failed or whose dependencies failed + self.started = set() + self.finished = set() + self.failed = set() def is_done(self): return len(self.finished) + len(self.failed) >= len(self.objects) @@ -79,16 +88,30 @@ class State(object): return set(self.objects) - self.started - self.finished - self.failed -def parallel_execute_stream(objects, func, get_deps): +def parallel_execute_iter(objects, func, get_deps): + """ + Runs func on objects in parallel while ensuring that func is + ran on object only after it is ran on all its dependencies. + + Returns an iterator of tuples which look like: + + # if func returned normally when run on object + (object, result, None) + + # if func raised an exception when run on object + (object, None, exception) + + # if func raised an exception when run on one of object's dependencies + (object, None, UpstreamError()) + """ if get_deps is None: get_deps = _no_deps results = Queue() state = State(objects) - while not state.is_done(): - for event in feed_queue(objects, func, get_deps, results, state): - yield event + while True: + feed_queue(objects, func, get_deps, results, state) try: event = results.get(timeout=0.1) @@ -98,6 +121,9 @@ def parallel_execute_stream(objects, func, get_deps): except thread.error: raise ShutdownException() + if event is STOP: + break + obj, _, exception = event if exception is None: log.debug('Finished processing: {}'.format(obj)) @@ -109,7 +135,11 @@ def parallel_execute_stream(objects, func, get_deps): yield event -def queue_producer(obj, func, results): +def producer(obj, func, results): + """ + The entry point for a producer thread which runs func on a single object. + Places a tuple on the results queue once func has either returned or raised. + """ try: result = func(obj) results.put((obj, result, None)) @@ -118,6 +148,13 @@ def queue_producer(obj, func, results): def feed_queue(objects, func, get_deps, results, state): + """ + Starts producer threads for any objects which are ready to be processed + (i.e. they have no dependencies which haven't been successfully processed). + + Shortcuts any objects whose dependencies have failed and places an + (object, None, UpstreamError()) tuple on the results queue. + """ pending = state.pending() log.debug('Pending: {}'.format(pending)) @@ -126,18 +163,21 @@ def feed_queue(objects, func, get_deps, results, state): if any(dep in state.failed for dep in deps): log.debug('{} has upstream errors - not processing'.format(obj)) - yield (obj, None, UpstreamError()) + results.put((obj, None, UpstreamError())) state.failed.add(obj) elif all( dep not in objects or dep in state.finished for dep in deps ): log.debug('Starting producer thread for {}'.format(obj)) - t = Thread(target=queue_producer, args=(obj, func, results)) + t = Thread(target=producer, args=(obj, func, results)) t.daemon = True t.start() state.started.add(obj) + if state.is_done(): + results.put(STOP) + class UpstreamError(Exception): pass diff --git a/tests/unit/parallel_test.py b/tests/unit/parallel_test.py index 9ed1b3623..45b0db1db 100644 --- a/tests/unit/parallel_test.py +++ b/tests/unit/parallel_test.py @@ -5,7 +5,7 @@ import six from docker.errors import APIError from compose.parallel import parallel_execute -from compose.parallel import parallel_execute_stream +from compose.parallel import parallel_execute_iter from compose.parallel import UpstreamError @@ -81,7 +81,7 @@ def test_parallel_execute_with_upstream_errors(): events = [ (obj, result, type(exception)) for obj, result, exception - in parallel_execute_stream(objects, process, get_deps) + in parallel_execute_iter(objects, process, get_deps) ] assert (cache, None, type(None)) in events