diff --git a/compose/cli/command.py b/compose/cli/command.py index e1ae690c0..6bfc23019 100644 --- a/compose/cli/command.py +++ b/compose/cli/command.py @@ -99,8 +99,13 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False, host=host, environment=environment ) + global_parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT') + if global_parallel_limit: + global_parallel_limit = int(global_parallel_limit) + with errors.handle_connection_errors(client): - return Project.from_config(project_name, config_data, client) + return Project.from_config(project_name, config_data, client, + global_parallel_limit=global_parallel_limit) def get_project_name(working_dir, project_name=None, environment=None): diff --git a/compose/parallel.py b/compose/parallel.py index 4f881c8f1..382ce0251 100644 --- a/compose/parallel.py +++ b/compose/parallel.py @@ -15,7 +15,6 @@ from six.moves.queue import Queue from compose.cli.colors import green from compose.cli.colors import red from compose.cli.signals import ShutdownException -from compose.config.environment import Environment from compose.const import PARALLEL_LIMIT from compose.errors import HealthCheckFailed from compose.errors import NoHealthCheckConfigured @@ -28,16 +27,17 @@ log = logging.getLogger(__name__) STOP = object() -def get_configured_limit(): - limit = Environment.from_command_line({'COMPOSE_PARALLEL_LIMIT': None})['COMPOSE_PARALLEL_LIMIT'] - if limit: - limit = int(limit) - else: - limit = PARALLEL_LIMIT - return limit +class GlobalLimit(object): + """Simple class to hold a global semaphore limiter for a project. This class + should be treated as a singleton that is instantiated when the project is. + """ + global_limiter = Semaphore(PARALLEL_LIMIT) -global_limiter = Semaphore(get_configured_limit()) + @classmethod + def set_global_limit(cls, value=None): + if value is not None: + cls.global_limiter = Semaphore(value) def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None): @@ -187,7 +187,7 @@ def producer(obj, func, results, limiter): The entry point for a producer thread which runs func on a single object. Places a tuple on the results queue once func has either returned or raised. """ - with limiter, global_limiter: + with limiter, GlobalLimit.global_limiter: try: result = func(obj) results.put((obj, result, None)) diff --git a/compose/project.py b/compose/project.py index 11ee4a0b7..c5bdbb16c 100644 --- a/compose/project.py +++ b/compose/project.py @@ -61,13 +61,15 @@ class Project(object): """ A collection of services. """ - def __init__(self, name, services, client, networks=None, volumes=None, config_version=None): + def __init__(self, name, services, client, networks=None, volumes=None, config_version=None, + parallel_limit=None): self.name = name self.services = services self.client = client self.volumes = volumes or ProjectVolumes({}) self.networks = networks or ProjectNetworks({}, False) self.config_version = config_version + parallel.GlobalLimit.set_global_limit(value=parallel_limit) def labels(self, one_off=OneOffFilter.exclude): labels = ['{0}={1}'.format(LABEL_PROJECT, self.name)] @@ -76,7 +78,7 @@ class Project(object): return labels @classmethod - def from_config(cls, name, config_data, client): + def from_config(cls, name, config_data, client, global_parallel_limit=None): """ Construct a Project from a config.Config object. """ @@ -87,7 +89,8 @@ class Project(object): networks, use_networking) volumes = ProjectVolumes.from_config(name, config_data, client) - project = cls(name, [], client, project_networks, volumes, config_data.version) + project = cls(name, [], client, project_networks, volumes, config_data.version, + parallel_limit=global_parallel_limit) for service_dict in config_data.services: service_dict = dict(service_dict) diff --git a/tests/unit/parallel_test.py b/tests/unit/parallel_test.py index 529395aa3..8ac6b339a 100644 --- a/tests/unit/parallel_test.py +++ b/tests/unit/parallel_test.py @@ -1,14 +1,12 @@ from __future__ import absolute_import from __future__ import unicode_literals -import os from threading import Lock import six from docker.errors import APIError -from .. import mock -from compose.parallel import get_configured_limit +from compose.parallel import GlobalLimit from compose.parallel import parallel_execute from compose.parallel import parallel_execute_iter from compose.parallel import ParallelStreamWriter @@ -70,14 +68,11 @@ def test_parallel_execute_with_limit(): assert errors == {} -@mock.patch.dict(os.environ) def test_parallel_execute_with_global_limit(): - os.environ['COMPOSE_PARALLEL_LIMIT'] = '1' + GlobalLimit.set_global_limit(1) tasks = 20 lock = Lock() - assert get_configured_limit() == 1 - def f(obj): locked = lock.acquire(False) # we should always get the lock because we're the only thread running