diff --git a/compose/container.py b/compose/container.py
index 1ca483809..dde83bd35 100644
--- a/compose/container.py
+++ b/compose/container.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals

+import operator
 from functools import reduce

 import six
@@ -8,6 +9,7 @@ import six
 from .const import LABEL_CONTAINER_NUMBER
 from .const import LABEL_PROJECT
 from .const import LABEL_SERVICE
+from compose.utils import parallel_execute


 class Container(object):
@@ -250,3 +252,40 @@ def get_container_name(container):
     # ps
     shortest_name = min(container['Names'], key=lambda n: len(n.split('/')))
     return shortest_name.split('/')[-1]
+
+
+def parallel_operation(containers, operation, options, message):
+    parallel_execute(
+        containers,
+        operator.methodcaller(operation, **options),
+        operator.attrgetter('name'),
+        message)
+
+
+def parallel_remove(containers, options):
+    stopped_containers = [c for c in containers if not c.is_running]
+    parallel_operation(stopped_containers, 'remove', options, 'Removing')
+
+
+def parallel_stop(containers, options):
+    parallel_operation(containers, 'stop', options, 'Stopping')
+
+
+def parallel_start(containers, options):
+    parallel_operation(containers, 'start', options, 'Starting')
+
+
+def parallel_pause(containers, options):
+    parallel_operation(containers, 'pause', options, 'Pausing')
+
+
+def parallel_unpause(containers, options):
+    parallel_operation(containers, 'unpause', options, 'Unpausing')
+
+
+def parallel_kill(containers, options):
+    parallel_operation(containers, 'kill', options, 'Killing')
+
+
+def parallel_restart(containers, options):
+    parallel_operation(containers, 'restart', options, 'Restarting')
diff --git a/compose/project.py b/compose/project.py
index 41af86261..dc6dd32fd 100644
--- a/compose/project.py
+++ b/compose/project.py
@@ -7,6 +7,7 @@ from functools import reduce
 from docker.errors import APIError
 from docker.errors import NotFound

+from . import container
 from .config import ConfigurationError
 from .config import get_service_name_from_net
 from .const import DEFAULT_TIMEOUT
@@ -22,7 +23,6 @@ from .service import parse_volume_from_spec
 from .service import Service
 from .service import ServiceNet
 from .service import VolumeFromSpec
-from .utils import parallel_execute


 log = logging.getLogger(__name__)
@@ -241,42 +241,22 @@ class Project(object):
             service.start(**options)

     def stop(self, service_names=None, **options):
-        parallel_execute(
-            objects=self.containers(service_names),
-            obj_callable=lambda c: c.stop(**options),
-            msg_index=lambda c: c.name,
-            msg="Stopping"
-        )
+        container.parallel_stop(self.containers(service_names), options)

     def pause(self, service_names=None, **options):
-        for service in reversed(self.get_services(service_names)):
-            service.pause(**options)
+        container.parallel_pause(reversed(self.containers(service_names)), options)

     def unpause(self, service_names=None, **options):
-        for service in self.get_services(service_names):
-            service.unpause(**options)
+        container.parallel_unpause(self.containers(service_names), options)

     def kill(self, service_names=None, **options):
-        parallel_execute(
-            objects=self.containers(service_names),
-            obj_callable=lambda c: c.kill(**options),
-            msg_index=lambda c: c.name,
-            msg="Killing"
-        )
+        container.parallel_kill(self.containers(service_names), options)

     def remove_stopped(self, service_names=None, **options):
-        all_containers = self.containers(service_names, stopped=True)
-        stopped_containers = [c for c in all_containers if not c.is_running]
-        parallel_execute(
-            objects=stopped_containers,
-            obj_callable=lambda c: c.remove(**options),
-            msg_index=lambda c: c.name,
-            msg="Removing"
-        )
+        container.parallel_remove(self.containers(service_names, stopped=True), options)

     def restart(self, service_names=None, **options):
-        for service in self.get_services(service_names):
-            service.restart(**options)
+        container.parallel_restart(self.containers(service_names, stopped=True), options)

     def build(self, service_names=None, no_cache=False, pull=False, force_rm=False):
         for service in self.get_services(service_names):
diff --git a/compose/service.py b/compose/service.py
index b79fd9001..ab6f6dd6c 100644
--- a/compose/service.py
+++ b/compose/service.py
@@ -28,6 +28,9 @@ from .const import LABEL_PROJECT
 from .const import LABEL_SERVICE
 from .const import LABEL_VERSION
 from .container import Container
+from .container import parallel_remove
+from .container import parallel_start
+from .container import parallel_stop
 from .legacy import check_for_legacy_containers
 from .progress_stream import stream_output
 from .progress_stream import StreamOutputError
@@ -241,12 +244,7 @@ class Service(object):
                 else:
                     containers_to_start = stopped_containers

-                parallel_execute(
-                    objects=containers_to_start,
-                    obj_callable=lambda c: c.start(),
-                    msg_index=lambda c: c.name,
-                    msg="Starting"
-                )
+                parallel_start(containers_to_start, {})

                 num_running += len(containers_to_start)

@@ -259,35 +257,22 @@ class Service(object):
             ]

             parallel_execute(
-                objects=container_numbers,
-                obj_callable=lambda n: create_and_start(service=self, number=n),
-                msg_index=lambda n: n,
-                msg="Creating and starting"
+                container_numbers,
+                lambda n: create_and_start(service=self, number=n),
+                lambda n: n,
+                "Creating and starting"
             )

         if desired_num < num_running:
             num_to_stop = num_running - desired_num
-            sorted_running_containers = sorted(running_containers, key=attrgetter('number'))
-            containers_to_stop = sorted_running_containers[-num_to_stop:]
+            sorted_running_containers = sorted(
+                running_containers,
+                key=attrgetter('number'))
+            parallel_stop(
+                sorted_running_containers[-num_to_stop:],
+                dict(timeout=timeout))

-            parallel_execute(
-                objects=containers_to_stop,
-                obj_callable=lambda c: c.stop(timeout=timeout),
-                msg_index=lambda c: c.name,
-                msg="Stopping"
-            )
-
-        self.remove_stopped()
-
-    def remove_stopped(self, **options):
-        containers = [c for c in self.containers(stopped=True) if not c.is_running]
-
-        parallel_execute(
-            objects=containers,
-            obj_callable=lambda c: c.remove(**options),
-            msg_index=lambda c: c.name,
-            msg="Removing"
-        )
+        parallel_remove(self.containers(stopped=True), {})

     def create_container(self,
                          one_off=False,
diff --git a/compose/utils.py b/compose/utils.py
index a013035e9..716f6633f 100644
--- a/compose/utils.py
+++ b/compose/utils.py
@@ -17,58 +17,51 @@ log = logging.getLogger(__name__)
 json_decoder = json.JSONDecoder()


-def parallel_execute(objects, obj_callable, msg_index, msg):
-    """
-    For a given list of objects, call the callable passing in the first
+def perform_operation(func, arg, callback, index):
+    try:
+        callback((index, func(arg)))
+    except Exception as e:
+        callback((index, e))
+
+
+def parallel_execute(objects, func, index_func, msg):
+    """For a given list of objects, call the callable passing in the first
     object we give it.
     """
+    objects = list(objects)
     stream = get_output_stream(sys.stdout)
-    lines = []
+    writer = ParallelStreamWriter(stream, msg)

     for obj in objects:
-        write_out_msg(stream, lines, msg_index(obj), msg)
+        writer.initialize(index_func(obj))

     q = Queue()

-    def inner_execute_function(an_callable, parameter, msg_index):
-        error = None
-        try:
-            result = an_callable(parameter)
-        except APIError as e:
-            error = e.explanation
-            result = "error"
-        except Exception as e:
-            error = e
-            result = 'unexpected_exception'
-
-        q.put((msg_index, result, error))
-
-    for an_object in objects:
+    # TODO: limit the number of threads #1828
+    for obj in objects:
         t = Thread(
-            target=inner_execute_function,
-            args=(obj_callable, an_object, msg_index(an_object)),
-        )
+            target=perform_operation,
+            args=(func, obj, q.put, index_func(obj)))
         t.daemon = True
         t.start()

     done = 0
     errors = {}
-    total_to_execute = len(objects)

-    while done < total_to_execute:
+    while done < len(objects):
         try:
-            msg_index, result, error = q.get(timeout=1)
-
-            if result == 'unexpected_exception':
-                errors[msg_index] = result, error
-            if result == 'error':
-                errors[msg_index] = result, error
-                write_out_msg(stream, lines, msg_index, msg, status='error')
-            else:
-                write_out_msg(stream, lines, msg_index, msg)
-            done += 1
+            msg_index, result = q.get(timeout=1)
         except Empty:
-            pass
+            continue
+
+        if isinstance(result, APIError):
+            errors[msg_index] = "error", result.explanation
+            writer.write(msg_index, 'error')
+        elif isinstance(result, Exception):
+            errors[msg_index] = "unexpected_exception", result
+        else:
+            writer.write(msg_index, 'done')
+        done += 1

     if not errors:
         return
@@ -80,6 +73,36 @@
     raise error


+class ParallelStreamWriter(object):
+    """Write out messages for operations happening in parallel.
+
+    Each operation has its own line, and ANSI code characters are used
+    to jump to the correct line, and write over the line.
+    """
+
+    def __init__(self, stream, msg):
+        self.stream = stream
+        self.msg = msg
+        self.lines = []
+
+    def initialize(self, obj_index):
+        self.lines.append(obj_index)
+        self.stream.write("{} {} ... \r\n".format(self.msg, obj_index))
+        self.stream.flush()
+
+    def write(self, obj_index, status):
+        position = self.lines.index(obj_index)
+        diff = len(self.lines) - position
+        # move up
+        self.stream.write("%c[%dA" % (27, diff))
+        # erase
+        self.stream.write("%c[2K\r" % 27)
+        self.stream.write("{} {} ... {}\r".format(self.msg, obj_index, status))
+        # move back down
+        self.stream.write("%c[%dB" % (27, diff))
+        self.stream.flush()
+
+
 def get_output_stream(stream):
     if six.PY3:
         return stream
@@ -151,30 +174,6 @@ def json_stream(stream):
     return split_buffer(stream, json_splitter, json_decoder.decode)


-def write_out_msg(stream, lines, msg_index, msg, status="done"):
-    """
-    Using special ANSI code characters we can write out the msg over the top of
-    a previous status message, if it exists.
-    """
-    obj_index = msg_index
-    if msg_index in lines:
-        position = lines.index(obj_index)
-        diff = len(lines) - position
-        # move up
-        stream.write("%c[%dA" % (27, diff))
-        # erase
-        stream.write("%c[2K\r" % 27)
-        stream.write("{} {} ... {}\r".format(msg, obj_index, status))
-        # move back down
-        stream.write("%c[%dB" % (27, diff))
-    else:
-        diff = 0
-        lines.append(obj_index)
-        stream.write("{} {} ... \r\n".format(msg, obj_index))
-
-    stream.flush()
-
-
 def json_hash(obj):
     dump = json.dumps(obj, sort_keys=True, separators=(',', ':'))
     h = hashlib.sha256()
diff --git a/tests/integration/service_test.py b/tests/integration/service_test.py
index aaa4f01ec..34869ab88 100644
--- a/tests/integration/service_test.py
+++ b/tests/integration/service_test.py
@@ -36,6 +36,12 @@ def create_and_start_container(service, **override_options):
     return container


+def remove_stopped(service):
+    containers = [c for c in service.containers(stopped=True) if not c.is_running]
+    for container in containers:
+        container.remove()
+
+
 class ServiceTest(DockerClientTestCase):
     def test_containers(self):
         foo = self.create_service('foo')
@@ -94,14 +100,14 @@
         create_and_start_container(service)
         self.assertEqual(len(service.containers()), 1)

-        service.remove_stopped()
+        remove_stopped(service)
         self.assertEqual(len(service.containers()), 1)

         service.kill()
         self.assertEqual(len(service.containers()), 0)
         self.assertEqual(len(service.containers(stopped=True)), 1)

-        service.remove_stopped()
+        remove_stopped(service)
         self.assertEqual(len(service.containers(stopped=True)), 0)

     def test_create_container_with_one_off(self):
@@ -659,9 +665,8 @@
         self.assertIn('Creating', captured_output)
         self.assertIn('Starting', captured_output)

-    def test_scale_with_api_returns_errors(self):
-        """
-        Test that when scaling if the API returns an error, that error is handled
+    def test_scale_with_api_error(self):
+        """Test that when scaling if the API returns an error, that error is handled
         and the remaining threads continue.
         """
         service = self.create_service('web')
@@ -670,7 +675,10 @@

         with mock.patch(
             'compose.container.Container.create',
-            side_effect=APIError(message="testing", response={}, explanation="Boom")):
+            side_effect=APIError(
+                message="testing",
+                response={},
+                explanation="Boom")):

             with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
                 service.scale(3)
@@ -679,9 +687,8 @@
         self.assertTrue(service.containers()[0].is_running)
         self.assertIn("ERROR: for 2 Boom", mock_stdout.getvalue())

-    def test_scale_with_api_returns_unexpected_exception(self):
-        """
-        Test that when scaling if the API returns an error, that is not of type
+    def test_scale_with_unexpected_exception(self):
+        """Test that when scaling if the API returns an error, that is not of type
         APIError, that error is re-raised.
         """
         service = self.create_service('web')
@@ -903,7 +910,7 @@
             self.assertIn(pair, labels)

         service.kill()
-        service.remove_stopped()
+        remove_stopped(service)

         labels_list = ["%s=%s" % pair for pair in labels_dict.items()]

diff --git a/tox.ini b/tox.ini
index d1098a55a..9d45b0c7f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -44,5 +44,5 @@ directory = coverage-html
 # Allow really long lines for now
 max-line-length = 140
 # Set this high for now
-max-complexity = 20
+max-complexity = 12
 exclude = compose/packages