diff --git a/compose/config/config.py b/compose/config/config.py index fa214767b..84b6748c9 100644 --- a/compose/config/config.py +++ b/compose/config/config.py @@ -1,5 +1,6 @@ import codecs import logging +import operator import os import sys from collections import namedtuple @@ -387,56 +388,46 @@ def merge_service_dicts_from_files(base, override): def merge_service_dicts(base, override): - d = base.copy() + d = {} - if 'environment' in base or 'environment' in override: - d['environment'] = merge_environment( - base.get('environment'), - override.get('environment'), - ) + def merge_field(field, merge_func, default=None): + if field in base or field in override: + d[field] = merge_func( + base.get(field, default), + override.get(field, default)) - path_mapping_keys = ['volumes', 'devices'] + merge_field('environment', merge_environment) + merge_field('labels', merge_labels) + merge_image_or_build(base, override, d) - for key in path_mapping_keys: - if key in base or key in override: - d[key] = merge_path_mappings( - base.get(key), - override.get(key), - ) + for field in ['volumes', 'devices']: + merge_field(field, merge_path_mappings) - if 'labels' in base or 'labels' in override: - d['labels'] = merge_labels( - base.get('labels'), - override.get('labels'), - ) + for field in ['ports', 'expose', 'external_links']: + merge_field(field, operator.add, default=[]) - if 'image' in override and 'build' in d: - del d['build'] + for field in ['dns', 'dns_search']: + merge_field(field, merge_list_or_string) - if 'build' in override and 'image' in d: - del d['image'] - - list_keys = ['ports', 'expose', 'external_links'] - - for key in list_keys: - if key in base or key in override: - d[key] = base.get(key, []) + override.get(key, []) - - list_or_string_keys = ['dns', 'dns_search'] - - for key in list_or_string_keys: - if key in base or key in override: - d[key] = to_list(base.get(key)) + to_list(override.get(key)) - - already_merged_keys = ['environment', 'labels'] + path_mapping_keys + list_keys + list_or_string_keys - - for k in set(ALLOWED_KEYS) - set(already_merged_keys): - if k in override: - d[k] = override[k] + already_merged_keys = set(d) | {'image', 'build'} + for field in set(ALLOWED_KEYS) - already_merged_keys: + if field in base or field in override: + d[field] = override.get(field, base.get(field)) return d +def merge_image_or_build(base, override, output): + if 'image' in override: + output['image'] = override['image'] + elif 'build' in override: + output['build'] = override['build'] + elif 'image' in base: + output['image'] = base['image'] + elif 'build' in base: + output['build'] = base['build'] + + def merge_environment(base, override): env = parse_environment(base) env.update(parse_environment(override)) @@ -602,6 +593,10 @@ def expand_path(working_dir, path): return os.path.abspath(os.path.join(working_dir, os.path.expanduser(path))) +def merge_list_or_string(base, override): + return to_list(base) + to_list(override) + + def to_list(value): if value is None: return [] diff --git a/compose/parallel.py b/compose/parallel.py new file mode 100644 index 000000000..2735a397f --- /dev/null +++ b/compose/parallel.py @@ -0,0 +1,135 @@ +from __future__ import absolute_import +from __future__ import unicode_literals + +import operator +import sys +from threading import Thread + +from docker.errors import APIError +from six.moves.queue import Empty +from six.moves.queue import Queue + +from compose.utils import get_output_stream + + +def perform_operation(func, arg, callback, index): + try: + callback((index, func(arg))) + except Exception as e: + callback((index, e)) + + +def parallel_execute(objects, func, index_func, msg): + """For a given list of objects, call the callable passing in the first + object we give it. + """ + objects = list(objects) + stream = get_output_stream(sys.stdout) + writer = ParallelStreamWriter(stream, msg) + + for obj in objects: + writer.initialize(index_func(obj)) + + q = Queue() + + # TODO: limit the number of threads #1828 + for obj in objects: + t = Thread( + target=perform_operation, + args=(func, obj, q.put, index_func(obj))) + t.daemon = True + t.start() + + done = 0 + errors = {} + + while done < len(objects): + try: + msg_index, result = q.get(timeout=1) + except Empty: + continue + + if isinstance(result, APIError): + errors[msg_index] = "error", result.explanation + writer.write(msg_index, 'error') + elif isinstance(result, Exception): + errors[msg_index] = "unexpected_exception", result + else: + writer.write(msg_index, 'done') + done += 1 + + if not errors: + return + + stream.write("\n") + for msg_index, (result, error) in errors.items(): + stream.write("ERROR: for {} {} \n".format(msg_index, error)) + if result == 'unexpected_exception': + raise error + + +class ParallelStreamWriter(object): + """Write out messages for operations happening in parallel. + + Each operation has it's own line, and ANSI code characters are used + to jump to the correct line, and write over the line. + """ + + def __init__(self, stream, msg): + self.stream = stream + self.msg = msg + self.lines = [] + + def initialize(self, obj_index): + self.lines.append(obj_index) + self.stream.write("{} {} ... \r\n".format(self.msg, obj_index)) + self.stream.flush() + + def write(self, obj_index, status): + position = self.lines.index(obj_index) + diff = len(self.lines) - position + # move up + self.stream.write("%c[%dA" % (27, diff)) + # erase + self.stream.write("%c[2K\r" % 27) + self.stream.write("{} {} ... {}\r".format(self.msg, obj_index, status)) + # move back down + self.stream.write("%c[%dB" % (27, diff)) + self.stream.flush() + + +def parallel_operation(containers, operation, options, message): + parallel_execute( + containers, + operator.methodcaller(operation, **options), + operator.attrgetter('name'), + message) + + +def parallel_remove(containers, options): + stopped_containers = [c for c in containers if not c.is_running] + parallel_operation(stopped_containers, 'remove', options, 'Removing') + + +def parallel_stop(containers, options): + parallel_operation(containers, 'stop', options, 'Stopping') + + +def parallel_start(containers, options): + parallel_operation(containers, 'start', options, 'Starting') + + +def parallel_pause(containers, options): + parallel_operation(containers, 'pause', options, 'Pausing') + + +def parallel_unpause(containers, options): + parallel_operation(containers, 'unpause', options, 'Unpausing') + + +def parallel_kill(containers, options): + parallel_operation(containers, 'kill', options, 'Killing') + + +def parallel_restart(containers, options): + parallel_operation(containers, 'restart', options, 'Restarting') diff --git a/compose/project.py b/compose/project.py index 41af86261..e29a2eb5a 100644 --- a/compose/project.py +++ b/compose/project.py @@ -7,6 +7,7 @@ from functools import reduce from docker.errors import APIError from docker.errors import NotFound +from . import parallel from .config import ConfigurationError from .config import get_service_name_from_net from .const import DEFAULT_TIMEOUT @@ -22,7 +23,6 @@ from .service import parse_volume_from_spec from .service import Service from .service import ServiceNet from .service import VolumeFromSpec -from .utils import parallel_execute log = logging.getLogger(__name__) @@ -241,42 +241,22 @@ class Project(object): service.start(**options) def stop(self, service_names=None, **options): - parallel_execute( - objects=self.containers(service_names), - obj_callable=lambda c: c.stop(**options), - msg_index=lambda c: c.name, - msg="Stopping" - ) + parallel.parallel_stop(self.containers(service_names), options) def pause(self, service_names=None, **options): - for service in reversed(self.get_services(service_names)): - service.pause(**options) + parallel.parallel_pause(reversed(self.containers(service_names)), options) def unpause(self, service_names=None, **options): - for service in self.get_services(service_names): - service.unpause(**options) + parallel.parallel_unpause(self.containers(service_names), options) def kill(self, service_names=None, **options): - parallel_execute( - objects=self.containers(service_names), - obj_callable=lambda c: c.kill(**options), - msg_index=lambda c: c.name, - msg="Killing" - ) + parallel.parallel_kill(self.containers(service_names), options) def remove_stopped(self, service_names=None, **options): - all_containers = self.containers(service_names, stopped=True) - stopped_containers = [c for c in all_containers if not c.is_running] - parallel_execute( - objects=stopped_containers, - obj_callable=lambda c: c.remove(**options), - msg_index=lambda c: c.name, - msg="Removing" - ) + parallel.parallel_remove(self.containers(service_names, stopped=True), options) def restart(self, service_names=None, **options): - for service in self.get_services(service_names): - service.restart(**options) + parallel.parallel_restart(self.containers(service_names, stopped=True), options) def build(self, service_names=None, no_cache=False, pull=False, force_rm=False): for service in self.get_services(service_names): diff --git a/compose/service.py b/compose/service.py index b79fd9001..dd2399ee3 100644 --- a/compose/service.py +++ b/compose/service.py @@ -29,10 +29,13 @@ from .const import LABEL_SERVICE from .const import LABEL_VERSION from .container import Container from .legacy import check_for_legacy_containers +from .parallel import parallel_execute +from .parallel import parallel_remove +from .parallel import parallel_start +from .parallel import parallel_stop from .progress_stream import stream_output from .progress_stream import StreamOutputError from .utils import json_hash -from .utils import parallel_execute log = logging.getLogger(__name__) @@ -241,12 +244,7 @@ class Service(object): else: containers_to_start = stopped_containers - parallel_execute( - objects=containers_to_start, - obj_callable=lambda c: c.start(), - msg_index=lambda c: c.name, - msg="Starting" - ) + parallel_start(containers_to_start, {}) num_running += len(containers_to_start) @@ -259,35 +257,22 @@ class Service(object): ] parallel_execute( - objects=container_numbers, - obj_callable=lambda n: create_and_start(service=self, number=n), - msg_index=lambda n: n, - msg="Creating and starting" + container_numbers, + lambda n: create_and_start(service=self, number=n), + lambda n: n, + "Creating and starting" ) if desired_num < num_running: num_to_stop = num_running - desired_num - sorted_running_containers = sorted(running_containers, key=attrgetter('number')) - containers_to_stop = sorted_running_containers[-num_to_stop:] + sorted_running_containers = sorted( + running_containers, + key=attrgetter('number')) + parallel_stop( + sorted_running_containers[-num_to_stop:], + dict(timeout=timeout)) - parallel_execute( - objects=containers_to_stop, - obj_callable=lambda c: c.stop(timeout=timeout), - msg_index=lambda c: c.name, - msg="Stopping" - ) - - self.remove_stopped() - - def remove_stopped(self, **options): - containers = [c for c in self.containers(stopped=True) if not c.is_running] - - parallel_execute( - objects=containers, - obj_callable=lambda c: c.remove(**options), - msg_index=lambda c: c.name, - msg="Removing" - ) + parallel_remove(self.containers(stopped=True), {}) def create_container(self, one_off=False, diff --git a/compose/utils.py b/compose/utils.py index a013035e9..362629bc2 100644 --- a/compose/utils.py +++ b/compose/utils.py @@ -2,84 +2,13 @@ import codecs import hashlib import json import json.decoder -import logging -import sys -from threading import Thread import six -from docker.errors import APIError -from six.moves.queue import Empty -from six.moves.queue import Queue -log = logging.getLogger(__name__) - json_decoder = json.JSONDecoder() -def parallel_execute(objects, obj_callable, msg_index, msg): - """ - For a given list of objects, call the callable passing in the first - object we give it. - """ - stream = get_output_stream(sys.stdout) - lines = [] - - for obj in objects: - write_out_msg(stream, lines, msg_index(obj), msg) - - q = Queue() - - def inner_execute_function(an_callable, parameter, msg_index): - error = None - try: - result = an_callable(parameter) - except APIError as e: - error = e.explanation - result = "error" - except Exception as e: - error = e - result = 'unexpected_exception' - - q.put((msg_index, result, error)) - - for an_object in objects: - t = Thread( - target=inner_execute_function, - args=(obj_callable, an_object, msg_index(an_object)), - ) - t.daemon = True - t.start() - - done = 0 - errors = {} - total_to_execute = len(objects) - - while done < total_to_execute: - try: - msg_index, result, error = q.get(timeout=1) - - if result == 'unexpected_exception': - errors[msg_index] = result, error - if result == 'error': - errors[msg_index] = result, error - write_out_msg(stream, lines, msg_index, msg, status='error') - else: - write_out_msg(stream, lines, msg_index, msg) - done += 1 - except Empty: - pass - - if not errors: - return - - stream.write("\n") - for msg_index, (result, error) in errors.items(): - stream.write("ERROR: for {} {} \n".format(msg_index, error)) - if result == 'unexpected_exception': - raise error - - def get_output_stream(stream): if six.PY3: return stream @@ -151,30 +80,6 @@ def json_stream(stream): return split_buffer(stream, json_splitter, json_decoder.decode) -def write_out_msg(stream, lines, msg_index, msg, status="done"): - """ - Using special ANSI code characters we can write out the msg over the top of - a previous status message, if it exists. - """ - obj_index = msg_index - if msg_index in lines: - position = lines.index(obj_index) - diff = len(lines) - position - # move up - stream.write("%c[%dA" % (27, diff)) - # erase - stream.write("%c[2K\r" % 27) - stream.write("{} {} ... {}\r".format(msg, obj_index, status)) - # move back down - stream.write("%c[%dB" % (27, diff)) - else: - diff = 0 - lines.append(obj_index) - stream.write("{} {} ... \r\n".format(msg, obj_index)) - - stream.flush() - - def json_hash(obj): dump = json.dumps(obj, sort_keys=True, separators=(',', ':')) h = hashlib.sha256() diff --git a/tests/integration/service_test.py b/tests/integration/service_test.py index aaa4f01ec..34869ab88 100644 --- a/tests/integration/service_test.py +++ b/tests/integration/service_test.py @@ -36,6 +36,12 @@ def create_and_start_container(service, **override_options): return container +def remove_stopped(service): + containers = [c for c in service.containers(stopped=True) if not c.is_running] + for container in containers: + container.remove() + + class ServiceTest(DockerClientTestCase): def test_containers(self): foo = self.create_service('foo') @@ -94,14 +100,14 @@ class ServiceTest(DockerClientTestCase): create_and_start_container(service) self.assertEqual(len(service.containers()), 1) - service.remove_stopped() + remove_stopped(service) self.assertEqual(len(service.containers()), 1) service.kill() self.assertEqual(len(service.containers()), 0) self.assertEqual(len(service.containers(stopped=True)), 1) - service.remove_stopped() + remove_stopped(service) self.assertEqual(len(service.containers(stopped=True)), 0) def test_create_container_with_one_off(self): @@ -659,9 +665,8 @@ class ServiceTest(DockerClientTestCase): self.assertIn('Creating', captured_output) self.assertIn('Starting', captured_output) - def test_scale_with_api_returns_errors(self): - """ - Test that when scaling if the API returns an error, that error is handled + def test_scale_with_api_error(self): + """Test that when scaling if the API returns an error, that error is handled and the remaining threads continue. """ service = self.create_service('web') @@ -670,7 +675,10 @@ class ServiceTest(DockerClientTestCase): with mock.patch( 'compose.container.Container.create', - side_effect=APIError(message="testing", response={}, explanation="Boom")): + side_effect=APIError( + message="testing", + response={}, + explanation="Boom")): with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout: service.scale(3) @@ -679,9 +687,8 @@ class ServiceTest(DockerClientTestCase): self.assertTrue(service.containers()[0].is_running) self.assertIn("ERROR: for 2 Boom", mock_stdout.getvalue()) - def test_scale_with_api_returns_unexpected_exception(self): - """ - Test that when scaling if the API returns an error, that is not of type + def test_scale_with_unexpected_exception(self): + """Test that when scaling if the API returns an error, that is not of type APIError, that error is re-raised. """ service = self.create_service('web') @@ -903,7 +910,7 @@ class ServiceTest(DockerClientTestCase): self.assertIn(pair, labels) service.kill() - service.remove_stopped() + remove_stopped(service) labels_list = ["%s=%s" % pair for pair in labels_dict.items()] diff --git a/tox.ini b/tox.ini index d1098a55a..9d45b0c7f 100644 --- a/tox.ini +++ b/tox.ini @@ -44,5 +44,5 @@ directory = coverage-html # Allow really long lines for now max-line-length = 140 # Set this high for now -max-complexity = 20 +max-complexity = 12 exclude = compose/packages