From 141b96bb312d85753de2189227941512bd42f33e Mon Sep 17 00:00:00 2001
From: Aanand Prasad
Date: Fri, 8 Apr 2016 17:46:13 +0100
Subject: [PATCH] Abort operations if their dependencies fail

Signed-off-by: Aanand Prasad
---
 compose/parallel.py         | 102 +++++++++++++++++++++---------------
 tests/unit/parallel_test.py |  73 ++++++++++++++++++++++++++
 2 files changed, 132 insertions(+), 43 deletions(-)
 create mode 100644 tests/unit/parallel_test.py

diff --git a/compose/parallel.py b/compose/parallel.py
index 79699236d..745d46351 100644
--- a/compose/parallel.py
+++ b/compose/parallel.py
@@ -32,7 +32,7 @@ def parallel_execute(objects, func, get_name, msg, get_deps=None):
     for obj in objects:
         writer.initialize(get_name(obj))
 
-    q = setup_queue(objects, func, get_deps, get_name)
+    q = setup_queue(objects, func, get_deps)
 
     done = 0
     errors = {}
@@ -54,6 +54,8 @@ def parallel_execute(objects, func, get_name, msg, get_deps=None):
         elif isinstance(exception, APIError):
             errors[get_name(obj)] = exception.explanation
             writer.write(get_name(obj), 'error')
+        elif isinstance(exception, UpstreamError):
+            writer.write(get_name(obj), 'error')
         else:
             errors[get_name(obj)] = exception
             error_to_reraise = exception
@@ -72,60 +74,74 @@ def _no_deps(x):
     return []
 
 
-def setup_queue(objects, func, get_deps, get_name):
+def setup_queue(objects, func, get_deps):
     if get_deps is None:
         get_deps = _no_deps
 
     results = Queue()
     output = Queue()
 
-    def consumer():
-        started = set()   # objects being processed
-        finished = set()  # objects which have been processed
-
-        def ready(obj):
-            """
-            Returns true if obj is ready to be processed:
-            - all dependencies have been processed
-            - obj is not already being processed
-            """
-            return obj not in started and all(
-                dep not in objects or dep in finished
-                for dep in get_deps(obj)
-            )
-
-        while len(finished) < len(objects):
-            for obj in filter(ready, objects):
-                log.debug('Starting producer thread for {}'.format(obj))
-                t = Thread(target=producer, args=(obj,))
-                t.daemon = True
-                t.start()
-                started.add(obj)
-
-            try:
-                event = results.get(timeout=1)
-            except Empty:
-                continue
-
-            obj = event[0]
-            log.debug('Finished processing: {}'.format(obj))
-            finished.add(obj)
-            output.put(event)
-
-    def producer(obj):
-        try:
-            result = func(obj)
-            results.put((obj, result, None))
-        except Exception as e:
-            results.put((obj, None, e))
-
-    t = Thread(target=consumer)
+    t = Thread(target=queue_consumer, args=(objects, func, get_deps, results, output))
     t.daemon = True
     t.start()
 
     return output
 
 
+def queue_producer(obj, func, results):
+    try:
+        result = func(obj)
+        results.put((obj, result, None))
+    except Exception as e:
+        results.put((obj, None, e))
+
+
+def queue_consumer(objects, func, get_deps, results, output):
+    started = set()   # objects being processed
+    finished = set()  # objects which have been processed
+    failed = set()    # objects which either failed or whose dependencies failed
+
+    while len(finished) + len(failed) < len(objects):
+        pending = set(objects) - started - finished - failed
+        log.debug('Pending: {}'.format(pending))
+
+        for obj in pending:
+            deps = get_deps(obj)
+
+            if any(dep in failed for dep in deps):
+                log.debug('{} has upstream errors - not processing'.format(obj))
+                output.put((obj, None, UpstreamError()))
+                failed.add(obj)
+            elif all(
+                dep not in objects or dep in finished
+                for dep in deps
+            ):
+                log.debug('Starting producer thread for {}'.format(obj))
+                t = Thread(target=queue_producer, args=(obj, func, results))
+                t.daemon = True
+                t.start()
+                started.add(obj)
+
+        try:
+            event = results.get(timeout=1)
+        except Empty:
+            continue
+
+        obj, _, exception = event
+        if exception is None:
+            log.debug('Finished processing: {}'.format(obj))
+            finished.add(obj)
+        else:
+            log.debug('Failed: {}'.format(obj))
+            failed.add(obj)
+
+        output.put(event)
+
+
+class UpstreamError(Exception):
+    pass
+
+
 class ParallelStreamWriter(object):
     """Write out messages for operations happening in parallel.
 
diff --git a/tests/unit/parallel_test.py b/tests/unit/parallel_test.py
new file mode 100644
index 000000000..6be560152
--- /dev/null
+++ b/tests/unit/parallel_test.py
@@ -0,0 +1,73 @@
+from __future__ import absolute_import
+from __future__ import unicode_literals
+
+import six
+from docker.errors import APIError
+
+from compose.parallel import parallel_execute
+
+
+web = 'web'
+db = 'db'
+data_volume = 'data_volume'
+cache = 'cache'
+
+objects = [web, db, data_volume, cache]
+
+deps = {
+    web: [db, cache],
+    db: [data_volume],
+    data_volume: [],
+    cache: [],
+}
+
+
+def test_parallel_execute():
+    results = parallel_execute(
+        objects=[1, 2, 3, 4, 5],
+        func=lambda x: x * 2,
+        get_name=six.text_type,
+        msg="Doubling",
+    )
+
+    assert sorted(results) == [2, 4, 6, 8, 10]
+
+
+def test_parallel_execute_with_deps():
+    log = []
+
+    def process(x):
+        log.append(x)
+
+    parallel_execute(
+        objects=objects,
+        func=process,
+        get_name=lambda obj: obj,
+        msg="Processing",
+        get_deps=lambda obj: deps[obj],
+    )
+
+    assert sorted(log) == sorted(objects)
+
+    assert log.index(data_volume) < log.index(db)
+    assert log.index(db) < log.index(web)
+    assert log.index(cache) < log.index(web)
+
+
+def test_parallel_execute_with_upstream_errors():
+    log = []
+
+    def process(x):
+        if x is data_volume:
+            raise APIError(None, None, "Something went wrong")
+        log.append(x)
+
+    parallel_execute(
+        objects=objects,
+        func=process,
+        get_name=lambda obj: obj,
+        msg="Processing",
+        get_deps=lambda obj: deps[obj],
+    )
+
+    assert log == [cache]
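
Usage sketch (illustrative, not part of the patch): after this change, an
object whose dependency fails is never passed to func. The consumer puts
(obj, None, UpstreamError()) on the output queue, and parallel_execute writes
an 'error' status for it without re-raising (see the hunk at @@ -54,6 +54,8).
A minimal standalone example mirroring the new test, assuming compose and
docker-py are importable:

    from docker.errors import APIError
    from compose.parallel import parallel_execute

    deps = {'web': ['db'], 'db': [], 'cache': []}
    processed = []

    def process(obj):
        # Simulate a failing dependency, as in
        # test_parallel_execute_with_upstream_errors.
        if obj == 'db':
            raise APIError(None, None, "Something went wrong")
        processed.append(obj)

    parallel_execute(
        objects=['web', 'db', 'cache'],
        func=process,
        get_name=lambda obj: obj,
        msg='Processing',
        get_deps=lambda obj: deps[obj],
    )

    # 'db' failed, so 'web' was aborted with UpstreamError and never ran;
    # 'cache' had no failed dependencies and completed normally.
    assert processed == ['cache']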