Abort operations if their dependencies fail

Author: Aanand Prasad <aanand.prasad@gmail.com>
Date:   2016-04-08 17:46:13 +01:00
Signed-off-by: Aanand Prasad <aanand.prasad@gmail.com>
parent bcdf541c8c
commit 141b96bb31

2 changed files with 132 additions and 43 deletions

compose/parallel.py

@@ -32,7 +32,7 @@ def parallel_execute(objects, func, get_name, msg, get_deps=None):
     for obj in objects:
         writer.initialize(get_name(obj))
 
-    q = setup_queue(objects, func, get_deps, get_name)
+    q = setup_queue(objects, func, get_deps)
 
     done = 0
     errors = {}
@@ -54,6 +54,8 @@ def parallel_execute(objects, func, get_name, msg, get_deps=None):
         elif isinstance(exception, APIError):
             errors[get_name(obj)] = exception.explanation
             writer.write(get_name(obj), 'error')
+        elif isinstance(exception, UpstreamError):
+            writer.write(get_name(obj), 'error')
         else:
            errors[get_name(obj)] = exception
            error_to_reraise = exception
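
The hunk above is the only change to the event loop: a new UpstreamError branch between the APIError case and the catch-all. Restated outside the diff for readability (the helper handle_event below is a hypothetical condensation, not code from this commit): a clean event is marked 'done'; an APIError is recorded by its explanation and marked 'error'; an UpstreamError is marked 'error' but not recorded, since the real failure was already reported against the dependency; anything else is handed back so the caller can re-raise it once every object has settled.

from docker.errors import APIError
from compose.parallel import UpstreamError  # added by this commit

def handle_event(obj, result, exception, get_name, writer, errors):
    name = get_name(obj)
    if exception is None:
        writer.write(name, 'done')             # success
    elif isinstance(exception, APIError):
        errors[name] = exception.explanation   # report Docker's message
        writer.write(name, 'error')
    elif isinstance(exception, UpstreamError):
        writer.write(name, 'error')            # dependency already reported
    else:
        errors[name] = exception
        return exception                       # becomes error_to_reraise
    return None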
@@ -72,60 +74,74 @@ def _no_deps(x):
     return []
 
 
-def setup_queue(objects, func, get_deps, get_name):
+def setup_queue(objects, func, get_deps):
     if get_deps is None:
         get_deps = _no_deps
 
     results = Queue()
     output = Queue()
 
-    def consumer():
-        started = set()   # objects being processed
-        finished = set()  # objects which have been processed
-
-        def ready(obj):
-            """
-            Returns true if obj is ready to be processed:
-            - all dependencies have been processed
-            - obj is not already being processed
-            """
-            return obj not in started and all(
-                dep not in objects or dep in finished
-                for dep in get_deps(obj)
-            )
-
-        while len(finished) < len(objects):
-            for obj in filter(ready, objects):
-                log.debug('Starting producer thread for {}'.format(obj))
-                t = Thread(target=producer, args=(obj,))
-                t.daemon = True
-                t.start()
-                started.add(obj)
-
-            try:
-                event = results.get(timeout=1)
-            except Empty:
-                continue
-
-            obj = event[0]
-            log.debug('Finished processing: {}'.format(obj))
-            finished.add(obj)
-
-            output.put(event)
-
-    def producer(obj):
-        try:
-            result = func(obj)
-            results.put((obj, result, None))
-        except Exception as e:
-            results.put((obj, None, e))
-
-    t = Thread(target=consumer)
+    t = Thread(target=queue_consumer, args=(objects, func, get_deps, results, output))
     t.daemon = True
     t.start()
 
     return output
 
 
+def queue_producer(obj, func, results):
+    try:
+        result = func(obj)
+        results.put((obj, result, None))
+    except Exception as e:
+        results.put((obj, None, e))
+
+
+def queue_consumer(objects, func, get_deps, results, output):
+    started = set()   # objects being processed
+    finished = set()  # objects which have been processed
+    failed = set()    # objects which either failed or whose dependencies failed
+
+    while len(finished) + len(failed) < len(objects):
+        pending = set(objects) - started - finished - failed
+        log.debug('Pending: {}'.format(pending))
+
+        for obj in pending:
+            deps = get_deps(obj)
+
+            if any(dep in failed for dep in deps):
+                log.debug('{} has upstream errors - not processing'.format(obj))
+                output.put((obj, None, UpstreamError()))
+                failed.add(obj)
+            elif all(
+                dep not in objects or dep in finished
+                for dep in deps
+            ):
+                log.debug('Starting producer thread for {}'.format(obj))
+                t = Thread(target=queue_producer, args=(obj, func, results))
+                t.daemon = True
+                t.start()
+                started.add(obj)
+
+        try:
+            event = results.get(timeout=1)
+        except Empty:
+            continue
+
+        obj, _, exception = event
+        if exception is None:
+            log.debug('Finished processing: {}'.format(obj))
+            finished.add(obj)
+        else:
+            log.debug('Failed: {}'.format(obj))
+            failed.add(obj)
+
+        output.put(event)
+
+
+class UpstreamError(Exception):
+    pass
+
+
 class ParallelStreamWriter(object):
     """Write out messages for operations happening in parallel.

tests/unit/parallel_test.py

@@ -0,0 +1,73 @@
+from __future__ import absolute_import
+from __future__ import unicode_literals
+
+import six
+from docker.errors import APIError
+
+from compose.parallel import parallel_execute
+
+
+web = 'web'
+db = 'db'
+data_volume = 'data_volume'
+cache = 'cache'
+
+objects = [web, db, data_volume, cache]
+
+deps = {
+    web: [db, cache],
+    db: [data_volume],
+    data_volume: [],
+    cache: [],
+}
+
+
+def test_parallel_execute():
+    results = parallel_execute(
+        objects=[1, 2, 3, 4, 5],
+        func=lambda x: x * 2,
+        get_name=six.text_type,
+        msg="Doubling",
+    )
+
+    assert sorted(results) == [2, 4, 6, 8, 10]
+
+
+def test_parallel_execute_with_deps():
+    log = []
+
+    def process(x):
+        log.append(x)
+
+    parallel_execute(
+        objects=objects,
+        func=process,
+        get_name=lambda obj: obj,
+        msg="Processing",
+        get_deps=lambda obj: deps[obj],
+    )
+
+    assert sorted(log) == sorted(objects)
+
+    assert log.index(data_volume) < log.index(db)
+    assert log.index(db) < log.index(web)
+    assert log.index(cache) < log.index(web)
+
+
+def test_parallel_execute_with_upstream_errors():
+    log = []
+
+    def process(x):
+        if x is data_volume:
+            raise APIError(None, None, "Something went wrong")
+        log.append(x)
+
+    parallel_execute(
+        objects=objects,
+        func=process,
+        get_name=lambda obj: obj,
+        msg="Processing",
+        get_deps=lambda obj: deps[obj],
+    )
+
+    assert log == [cache]
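
The last assertion captures the new behavior end to end: data_volume raises, so db and web are aborted with UpstreamError before process ever runs for them, and only cache, whose dependency chain is clean, reaches the log. A caller would wire the same API to real work along these lines (a hypothetical example, not part of the commit; the service names and the start function are made up):

import six

from compose.parallel import parallel_execute

services = ['data_volume', 'db', 'web']
dependencies = {'data_volume': [], 'db': ['data_volume'], 'web': ['db']}

def start(name):
    # Stand-in for real work. If this raises for 'data_volume',
    # 'db' and 'web' are aborted rather than started.
    print('starting {}'.format(name))

parallel_execute(
    objects=services,
    func=start,
    get_name=six.text_type,
    msg='Starting',
    get_deps=lambda name: dependencies[name],
)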