diff --git a/compose/cli/command.py b/compose/cli/command.py index 6bfc23019..7cea91a2d 100644 --- a/compose/cli/command.py +++ b/compose/cli/command.py @@ -10,6 +10,7 @@ import six from . import errors from . import verbose_proxy from .. import config +from .. import parallel from ..config.environment import Environment from ..const import API_VERSIONS from ..project import Project @@ -23,6 +24,8 @@ log = logging.getLogger(__name__) def project_from_options(project_dir, options): environment = Environment.from_env_file(project_dir) + set_parallel_limit(environment) + host = options.get('--host') if host is not None: host = host.lstrip('=') @@ -38,6 +41,22 @@ def project_from_options(project_dir, options): ) +def set_parallel_limit(environment): + parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT') + if parallel_limit: + try: + parallel_limit = int(parallel_limit) + except ValueError: + raise errors.UserError( + 'COMPOSE_PARALLEL_LIMIT must be an integer (found: "{}")'.format( + environment.get('COMPOSE_PARALLEL_LIMIT') + ) + ) + if parallel_limit <= 1: + raise errors.UserError('COMPOSE_PARALLEL_LIMIT can not be less than 2') + parallel.GlobalLimit.set_global_limit(parallel_limit) + + def get_config_from_options(base_dir, options): environment = Environment.from_env_file(base_dir) config_path = get_config_path_from_options( @@ -99,13 +118,8 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False, host=host, environment=environment ) - global_parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT') - if global_parallel_limit: - global_parallel_limit = int(global_parallel_limit) - with errors.handle_connection_errors(client): - return Project.from_config(project_name, config_data, client, - global_parallel_limit=global_parallel_limit) + return Project.from_config(project_name, config_data, client) def get_project_name(working_dir, project_name=None, environment=None): diff --git a/compose/parallel.py b/compose/parallel.py index 382ce0251..3c0098c05 100644 --- a/compose/parallel.py +++ b/compose/parallel.py @@ -35,9 +35,10 @@ class GlobalLimit(object): global_limiter = Semaphore(PARALLEL_LIMIT) @classmethod - def set_global_limit(cls, value=None): - if value is not None: - cls.global_limiter = Semaphore(value) + def set_global_limit(cls, value): + if value is None: + value = PARALLEL_LIMIT + cls.global_limiter = Semaphore(value) def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None): diff --git a/compose/project.py b/compose/project.py index c5bdbb16c..11ee4a0b7 100644 --- a/compose/project.py +++ b/compose/project.py @@ -61,15 +61,13 @@ class Project(object): """ A collection of services. """ - def __init__(self, name, services, client, networks=None, volumes=None, config_version=None, - parallel_limit=None): + def __init__(self, name, services, client, networks=None, volumes=None, config_version=None): self.name = name self.services = services self.client = client self.volumes = volumes or ProjectVolumes({}) self.networks = networks or ProjectNetworks({}, False) self.config_version = config_version - parallel.GlobalLimit.set_global_limit(value=parallel_limit) def labels(self, one_off=OneOffFilter.exclude): labels = ['{0}={1}'.format(LABEL_PROJECT, self.name)] @@ -78,7 +76,7 @@ class Project(object): return labels @classmethod - def from_config(cls, name, config_data, client, global_parallel_limit=None): + def from_config(cls, name, config_data, client): """ Construct a Project from a config.Config object. """ @@ -89,8 +87,7 @@ class Project(object): networks, use_networking) volumes = ProjectVolumes.from_config(name, config_data, client) - project = cls(name, [], client, project_networks, volumes, config_data.version, - parallel_limit=global_parallel_limit) + project = cls(name, [], client, project_networks, volumes, config_data.version) for service_dict in config_data.services: service_dict = dict(service_dict) diff --git a/tests/unit/parallel_test.py b/tests/unit/parallel_test.py index 8ac6b339a..4ebc24d8c 100644 --- a/tests/unit/parallel_test.py +++ b/tests/unit/parallel_test.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from __future__ import unicode_literals +import unittest from threading import Lock import six @@ -32,114 +33,113 @@ def get_deps(obj): return [(dep, None) for dep in deps[obj]] -def test_parallel_execute(): - results, errors = parallel_execute( - objects=[1, 2, 3, 4, 5], - func=lambda x: x * 2, - get_name=six.text_type, - msg="Doubling", - ) +class ParallelTest(unittest.TestCase): - assert sorted(results) == [2, 4, 6, 8, 10] - assert errors == {} + def test_parallel_execute(self): + results, errors = parallel_execute( + objects=[1, 2, 3, 4, 5], + func=lambda x: x * 2, + get_name=six.text_type, + msg="Doubling", + ) + assert sorted(results) == [2, 4, 6, 8, 10] + assert errors == {} -def test_parallel_execute_with_limit(): - limit = 1 - tasks = 20 - lock = Lock() + def test_parallel_execute_with_limit(self): + limit = 1 + tasks = 20 + lock = Lock() - def f(obj): - locked = lock.acquire(False) - # we should always get the lock because we're the only thread running - assert locked - lock.release() - return None + def f(obj): + locked = lock.acquire(False) + # we should always get the lock because we're the only thread running + assert locked + lock.release() + return None - results, errors = parallel_execute( - objects=list(range(tasks)), - func=f, - get_name=six.text_type, - msg="Testing", - limit=limit, - ) + results, errors = parallel_execute( + objects=list(range(tasks)), + func=f, + get_name=six.text_type, + msg="Testing", + limit=limit, + ) - assert results == tasks * [None] - assert errors == {} + assert results == tasks * [None] + assert errors == {} + def test_parallel_execute_with_global_limit(self): + GlobalLimit.set_global_limit(1) + self.addCleanup(GlobalLimit.set_global_limit, None) + tasks = 20 + lock = Lock() -def test_parallel_execute_with_global_limit(): - GlobalLimit.set_global_limit(1) - tasks = 20 - lock = Lock() + def f(obj): + locked = lock.acquire(False) + # we should always get the lock because we're the only thread running + assert locked + lock.release() + return None - def f(obj): - locked = lock.acquire(False) - # we should always get the lock because we're the only thread running - assert locked - lock.release() - return None + results, errors = parallel_execute( + objects=list(range(tasks)), + func=f, + get_name=six.text_type, + msg="Testing", + ) - results, errors = parallel_execute( - objects=list(range(tasks)), - func=f, - get_name=six.text_type, - msg="Testing", - ) + assert results == tasks * [None] + assert errors == {} - assert results == tasks * [None] - assert errors == {} + def test_parallel_execute_with_deps(self): + log = [] + def process(x): + log.append(x) -def test_parallel_execute_with_deps(): - log = [] + parallel_execute( + objects=objects, + func=process, + get_name=lambda obj: obj, + msg="Processing", + get_deps=get_deps, + ) - def process(x): - log.append(x) + assert sorted(log) == sorted(objects) - parallel_execute( - objects=objects, - func=process, - get_name=lambda obj: obj, - msg="Processing", - get_deps=get_deps, - ) + assert log.index(data_volume) < log.index(db) + assert log.index(db) < log.index(web) + assert log.index(cache) < log.index(web) - assert sorted(log) == sorted(objects) + def test_parallel_execute_with_upstream_errors(self): + log = [] - assert log.index(data_volume) < log.index(db) - assert log.index(db) < log.index(web) - assert log.index(cache) < log.index(web) + def process(x): + if x is data_volume: + raise APIError(None, None, "Something went wrong") + log.append(x) + parallel_execute( + objects=objects, + func=process, + get_name=lambda obj: obj, + msg="Processing", + get_deps=get_deps, + ) -def test_parallel_execute_with_upstream_errors(): - log = [] + assert log == [cache] - def process(x): - if x is data_volume: - raise APIError(None, None, "Something went wrong") - log.append(x) + events = [ + (obj, result, type(exception)) + for obj, result, exception + in parallel_execute_iter(objects, process, get_deps, None) + ] - parallel_execute( - objects=objects, - func=process, - get_name=lambda obj: obj, - msg="Processing", - get_deps=get_deps, - ) - - assert log == [cache] - - events = [ - (obj, result, type(exception)) - for obj, result, exception - in parallel_execute_iter(objects, process, get_deps, None) - ] - - assert (cache, None, type(None)) in events - assert (data_volume, None, APIError) in events - assert (db, None, UpstreamError) in events - assert (web, None, UpstreamError) in events + assert (cache, None, type(None)) in events + assert (data_volume, None, APIError) in events + assert (db, None, UpstreamError) in events + assert (web, None, UpstreamError) in events def test_parallel_execute_alignment(capsys):