Refactor parallel_execute.

Signed-off-by: Daniel Nephin <dnephin@docker.com>
This commit is contained in:
Daniel Nephin 2015-11-12 20:44:05 -05:00
parent c4096525c2
commit d1adbb9b25
6 changed files with 137 additions and 127 deletions

View File

@ -1,6 +1,7 @@
from __future__ import absolute_import
from __future__ import unicode_literals
import operator
from functools import reduce
import six
@ -8,6 +9,7 @@ import six
from .const import LABEL_CONTAINER_NUMBER
from .const import LABEL_PROJECT
from .const import LABEL_SERVICE
from compose.utils import parallel_execute
class Container(object):
@ -250,3 +252,40 @@ def get_container_name(container):
# ps
shortest_name = min(container['Names'], key=lambda n: len(n.split('/')))
return shortest_name.split('/')[-1]
def parallel_operation(containers, operation, options, message):
parallel_execute(
containers,
operator.methodcaller(operation, **options),
operator.attrgetter('name'),
message)
def parallel_remove(containers, options):
stopped_containers = [c for c in containers if not c.is_running]
parallel_operation(stopped_containers, 'remove', options, 'Removing')
def parallel_stop(containers, options):
parallel_operation(containers, 'stop', options, 'Stopping')
def parallel_start(containers, options):
parallel_operation(containers, 'start', options, 'Starting')
def parallel_pause(containers, options):
parallel_operation(containers, 'pause', options, 'Pausing')
def parallel_unpause(containers, options):
parallel_operation(containers, 'unpause', options, 'Unpausing')
def parallel_kill(containers, options):
parallel_operation(containers, 'kill', options, 'Killing')
def parallel_restart(containers, options):
parallel_operation(containers, 'restart', options, 'Restarting')

View File

@ -7,6 +7,7 @@ from functools import reduce
from docker.errors import APIError
from docker.errors import NotFound
from . import container
from .config import ConfigurationError
from .config import get_service_name_from_net
from .const import DEFAULT_TIMEOUT
@ -22,7 +23,6 @@ from .service import parse_volume_from_spec
from .service import Service
from .service import ServiceNet
from .service import VolumeFromSpec
from .utils import parallel_execute
log = logging.getLogger(__name__)
@ -241,42 +241,22 @@ class Project(object):
service.start(**options)
def stop(self, service_names=None, **options):
parallel_execute(
objects=self.containers(service_names),
obj_callable=lambda c: c.stop(**options),
msg_index=lambda c: c.name,
msg="Stopping"
)
container.parallel_stop(self.containers(service_names), options)
def pause(self, service_names=None, **options):
for service in reversed(self.get_services(service_names)):
service.pause(**options)
container.parallel_pause(reversed(self.containers(service_names)), options)
def unpause(self, service_names=None, **options):
for service in self.get_services(service_names):
service.unpause(**options)
container.parallel_unpause(self.containers(service_names), options)
def kill(self, service_names=None, **options):
parallel_execute(
objects=self.containers(service_names),
obj_callable=lambda c: c.kill(**options),
msg_index=lambda c: c.name,
msg="Killing"
)
container.parallel_kill(self.containers(service_names), options)
def remove_stopped(self, service_names=None, **options):
all_containers = self.containers(service_names, stopped=True)
stopped_containers = [c for c in all_containers if not c.is_running]
parallel_execute(
objects=stopped_containers,
obj_callable=lambda c: c.remove(**options),
msg_index=lambda c: c.name,
msg="Removing"
)
container.parallel_remove(self.containers(service_names, stopped=True), options)
def restart(self, service_names=None, **options):
for service in self.get_services(service_names):
service.restart(**options)
container.parallel_restart(self.containers(service_names, stopped=True), options)
def build(self, service_names=None, no_cache=False, pull=False, force_rm=False):
for service in self.get_services(service_names):

View File

@ -28,6 +28,9 @@ from .const import LABEL_PROJECT
from .const import LABEL_SERVICE
from .const import LABEL_VERSION
from .container import Container
from .container import parallel_remove
from .container import parallel_start
from .container import parallel_stop
from .legacy import check_for_legacy_containers
from .progress_stream import stream_output
from .progress_stream import StreamOutputError
@ -241,12 +244,7 @@ class Service(object):
else:
containers_to_start = stopped_containers
parallel_execute(
objects=containers_to_start,
obj_callable=lambda c: c.start(),
msg_index=lambda c: c.name,
msg="Starting"
)
parallel_start(containers_to_start, {})
num_running += len(containers_to_start)
@ -259,35 +257,22 @@ class Service(object):
]
parallel_execute(
objects=container_numbers,
obj_callable=lambda n: create_and_start(service=self, number=n),
msg_index=lambda n: n,
msg="Creating and starting"
container_numbers,
lambda n: create_and_start(service=self, number=n),
lambda n: n,
"Creating and starting"
)
if desired_num < num_running:
num_to_stop = num_running - desired_num
sorted_running_containers = sorted(running_containers, key=attrgetter('number'))
containers_to_stop = sorted_running_containers[-num_to_stop:]
sorted_running_containers = sorted(
running_containers,
key=attrgetter('number'))
parallel_stop(
sorted_running_containers[-num_to_stop:],
dict(timeout=timeout))
parallel_execute(
objects=containers_to_stop,
obj_callable=lambda c: c.stop(timeout=timeout),
msg_index=lambda c: c.name,
msg="Stopping"
)
self.remove_stopped()
def remove_stopped(self, **options):
containers = [c for c in self.containers(stopped=True) if not c.is_running]
parallel_execute(
objects=containers,
obj_callable=lambda c: c.remove(**options),
msg_index=lambda c: c.name,
msg="Removing"
)
parallel_remove(self.containers(stopped=True), {})
def create_container(self,
one_off=False,

View File

@ -17,58 +17,51 @@ log = logging.getLogger(__name__)
json_decoder = json.JSONDecoder()
def parallel_execute(objects, obj_callable, msg_index, msg):
"""
For a given list of objects, call the callable passing in the first
def perform_operation(func, arg, callback, index):
try:
callback((index, func(arg)))
except Exception as e:
callback((index, e))
def parallel_execute(objects, func, index_func, msg):
"""For a given list of objects, call the callable passing in the first
object we give it.
"""
objects = list(objects)
stream = get_output_stream(sys.stdout)
lines = []
writer = ParallelStreamWriter(stream, msg)
for obj in objects:
write_out_msg(stream, lines, msg_index(obj), msg)
writer.initialize(index_func(obj))
q = Queue()
def inner_execute_function(an_callable, parameter, msg_index):
error = None
try:
result = an_callable(parameter)
except APIError as e:
error = e.explanation
result = "error"
except Exception as e:
error = e
result = 'unexpected_exception'
q.put((msg_index, result, error))
for an_object in objects:
# TODO: limit the number of threads #1828
for obj in objects:
t = Thread(
target=inner_execute_function,
args=(obj_callable, an_object, msg_index(an_object)),
)
target=perform_operation,
args=(func, obj, q.put, index_func(obj)))
t.daemon = True
t.start()
done = 0
errors = {}
total_to_execute = len(objects)
while done < total_to_execute:
while done < len(objects):
try:
msg_index, result, error = q.get(timeout=1)
if result == 'unexpected_exception':
errors[msg_index] = result, error
if result == 'error':
errors[msg_index] = result, error
write_out_msg(stream, lines, msg_index, msg, status='error')
else:
write_out_msg(stream, lines, msg_index, msg)
done += 1
msg_index, result = q.get(timeout=1)
except Empty:
pass
continue
if isinstance(result, APIError):
errors[msg_index] = "error", result.explanation
writer.write(msg_index, 'error')
elif isinstance(result, Exception):
errors[msg_index] = "unexpected_exception", result
else:
writer.write(msg_index, 'done')
done += 1
if not errors:
return
@ -80,6 +73,36 @@ def parallel_execute(objects, obj_callable, msg_index, msg):
raise error
class ParallelStreamWriter(object):
"""Write out messages for operations happening in parallel.
Each operation has it's own line, and ANSI code characters are used
to jump to the correct line, and write over the line.
"""
def __init__(self, stream, msg):
self.stream = stream
self.msg = msg
self.lines = []
def initialize(self, obj_index):
self.lines.append(obj_index)
self.stream.write("{} {} ... \r\n".format(self.msg, obj_index))
self.stream.flush()
def write(self, obj_index, status):
position = self.lines.index(obj_index)
diff = len(self.lines) - position
# move up
self.stream.write("%c[%dA" % (27, diff))
# erase
self.stream.write("%c[2K\r" % 27)
self.stream.write("{} {} ... {}\r".format(self.msg, obj_index, status))
# move back down
self.stream.write("%c[%dB" % (27, diff))
self.stream.flush()
def get_output_stream(stream):
if six.PY3:
return stream
@ -151,30 +174,6 @@ def json_stream(stream):
return split_buffer(stream, json_splitter, json_decoder.decode)
def write_out_msg(stream, lines, msg_index, msg, status="done"):
"""
Using special ANSI code characters we can write out the msg over the top of
a previous status message, if it exists.
"""
obj_index = msg_index
if msg_index in lines:
position = lines.index(obj_index)
diff = len(lines) - position
# move up
stream.write("%c[%dA" % (27, diff))
# erase
stream.write("%c[2K\r" % 27)
stream.write("{} {} ... {}\r".format(msg, obj_index, status))
# move back down
stream.write("%c[%dB" % (27, diff))
else:
diff = 0
lines.append(obj_index)
stream.write("{} {} ... \r\n".format(msg, obj_index))
stream.flush()
def json_hash(obj):
dump = json.dumps(obj, sort_keys=True, separators=(',', ':'))
h = hashlib.sha256()

View File

@ -36,6 +36,12 @@ def create_and_start_container(service, **override_options):
return container
def remove_stopped(service):
containers = [c for c in service.containers(stopped=True) if not c.is_running]
for container in containers:
container.remove()
class ServiceTest(DockerClientTestCase):
def test_containers(self):
foo = self.create_service('foo')
@ -94,14 +100,14 @@ class ServiceTest(DockerClientTestCase):
create_and_start_container(service)
self.assertEqual(len(service.containers()), 1)
service.remove_stopped()
remove_stopped(service)
self.assertEqual(len(service.containers()), 1)
service.kill()
self.assertEqual(len(service.containers()), 0)
self.assertEqual(len(service.containers(stopped=True)), 1)
service.remove_stopped()
remove_stopped(service)
self.assertEqual(len(service.containers(stopped=True)), 0)
def test_create_container_with_one_off(self):
@ -659,9 +665,8 @@ class ServiceTest(DockerClientTestCase):
self.assertIn('Creating', captured_output)
self.assertIn('Starting', captured_output)
def test_scale_with_api_returns_errors(self):
"""
Test that when scaling if the API returns an error, that error is handled
def test_scale_with_api_error(self):
"""Test that when scaling if the API returns an error, that error is handled
and the remaining threads continue.
"""
service = self.create_service('web')
@ -670,7 +675,10 @@ class ServiceTest(DockerClientTestCase):
with mock.patch(
'compose.container.Container.create',
side_effect=APIError(message="testing", response={}, explanation="Boom")):
side_effect=APIError(
message="testing",
response={},
explanation="Boom")):
with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
service.scale(3)
@ -679,9 +687,8 @@ class ServiceTest(DockerClientTestCase):
self.assertTrue(service.containers()[0].is_running)
self.assertIn("ERROR: for 2 Boom", mock_stdout.getvalue())
def test_scale_with_api_returns_unexpected_exception(self):
"""
Test that when scaling if the API returns an error, that is not of type
def test_scale_with_unexpected_exception(self):
"""Test that when scaling if the API returns an error, that is not of type
APIError, that error is re-raised.
"""
service = self.create_service('web')
@ -903,7 +910,7 @@ class ServiceTest(DockerClientTestCase):
self.assertIn(pair, labels)
service.kill()
service.remove_stopped()
remove_stopped(service)
labels_list = ["%s=%s" % pair for pair in labels_dict.items()]

View File

@ -44,5 +44,5 @@ directory = coverage-html
# Allow really long lines for now
max-line-length = 140
# Set this high for now
max-complexity = 20
max-complexity = 12
exclude = compose/packages