Allow parallel limit to be set in env file.

Signed-off-by: Ashlie Martinez <ashmrtn@utexas.edu>
This commit is contained in:
Ashlie Martinez 2017-12-16 18:11:55 -06:00
parent dc6b464751
commit acf76c15a2
4 changed files with 24 additions and 21 deletions

View File

@ -99,8 +99,13 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False,
host=host, environment=environment host=host, environment=environment
) )
global_parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
if global_parallel_limit:
global_parallel_limit = int(global_parallel_limit)
with errors.handle_connection_errors(client): with errors.handle_connection_errors(client):
return Project.from_config(project_name, config_data, client) return Project.from_config(project_name, config_data, client,
global_parallel_limit=global_parallel_limit)
def get_project_name(working_dir, project_name=None, environment=None): def get_project_name(working_dir, project_name=None, environment=None):

View File

@ -15,7 +15,6 @@ from six.moves.queue import Queue
from compose.cli.colors import green from compose.cli.colors import green
from compose.cli.colors import red from compose.cli.colors import red
from compose.cli.signals import ShutdownException from compose.cli.signals import ShutdownException
from compose.config.environment import Environment
from compose.const import PARALLEL_LIMIT from compose.const import PARALLEL_LIMIT
from compose.errors import HealthCheckFailed from compose.errors import HealthCheckFailed
from compose.errors import NoHealthCheckConfigured from compose.errors import NoHealthCheckConfigured
@ -28,16 +27,17 @@ log = logging.getLogger(__name__)
STOP = object() STOP = object()
def get_configured_limit(): class GlobalLimit(object):
limit = Environment.from_command_line({'COMPOSE_PARALLEL_LIMIT': None})['COMPOSE_PARALLEL_LIMIT'] """Simple class to hold a global semaphore limiter for a project. This class
if limit: should be treated as a singleton that is instantiated when the project is.
limit = int(limit) """
else:
limit = PARALLEL_LIMIT
return limit
global_limiter = Semaphore(PARALLEL_LIMIT)
global_limiter = Semaphore(get_configured_limit()) @classmethod
def set_global_limit(cls, value=None):
if value is not None:
cls.global_limiter = Semaphore(value)
def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None): def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None):
@ -187,7 +187,7 @@ def producer(obj, func, results, limiter):
The entry point for a producer thread which runs func on a single object. The entry point for a producer thread which runs func on a single object.
Places a tuple on the results queue once func has either returned or raised. Places a tuple on the results queue once func has either returned or raised.
""" """
with limiter, global_limiter: with limiter, GlobalLimit.global_limiter:
try: try:
result = func(obj) result = func(obj)
results.put((obj, result, None)) results.put((obj, result, None))

View File

@ -61,13 +61,15 @@ class Project(object):
""" """
A collection of services. A collection of services.
""" """
def __init__(self, name, services, client, networks=None, volumes=None, config_version=None): def __init__(self, name, services, client, networks=None, volumes=None, config_version=None,
parallel_limit=None):
self.name = name self.name = name
self.services = services self.services = services
self.client = client self.client = client
self.volumes = volumes or ProjectVolumes({}) self.volumes = volumes or ProjectVolumes({})
self.networks = networks or ProjectNetworks({}, False) self.networks = networks or ProjectNetworks({}, False)
self.config_version = config_version self.config_version = config_version
parallel.GlobalLimit.set_global_limit(value=parallel_limit)
def labels(self, one_off=OneOffFilter.exclude): def labels(self, one_off=OneOffFilter.exclude):
labels = ['{0}={1}'.format(LABEL_PROJECT, self.name)] labels = ['{0}={1}'.format(LABEL_PROJECT, self.name)]
@ -76,7 +78,7 @@ class Project(object):
return labels return labels
@classmethod @classmethod
def from_config(cls, name, config_data, client): def from_config(cls, name, config_data, client, global_parallel_limit=None):
""" """
Construct a Project from a config.Config object. Construct a Project from a config.Config object.
""" """
@ -87,7 +89,8 @@ class Project(object):
networks, networks,
use_networking) use_networking)
volumes = ProjectVolumes.from_config(name, config_data, client) volumes = ProjectVolumes.from_config(name, config_data, client)
project = cls(name, [], client, project_networks, volumes, config_data.version) project = cls(name, [], client, project_networks, volumes, config_data.version,
parallel_limit=global_parallel_limit)
for service_dict in config_data.services: for service_dict in config_data.services:
service_dict = dict(service_dict) service_dict = dict(service_dict)

View File

@ -1,14 +1,12 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
import os
from threading import Lock from threading import Lock
import six import six
from docker.errors import APIError from docker.errors import APIError
from .. import mock from compose.parallel import GlobalLimit
from compose.parallel import get_configured_limit
from compose.parallel import parallel_execute from compose.parallel import parallel_execute
from compose.parallel import parallel_execute_iter from compose.parallel import parallel_execute_iter
from compose.parallel import ParallelStreamWriter from compose.parallel import ParallelStreamWriter
@ -70,14 +68,11 @@ def test_parallel_execute_with_limit():
assert errors == {} assert errors == {}
@mock.patch.dict(os.environ)
def test_parallel_execute_with_global_limit(): def test_parallel_execute_with_global_limit():
os.environ['COMPOSE_PARALLEL_LIMIT'] = '1' GlobalLimit.set_global_limit(1)
tasks = 20 tasks = 20
lock = Lock() lock = Lock()
assert get_configured_limit() == 1
def f(obj): def f(obj):
locked = lock.acquire(False) locked = lock.acquire(False)
# we should always get the lock because we're the only thread running # we should always get the lock because we're the only thread running