mirror of https://github.com/docker/compose.git
Clean up limit setting code and add reasonable input guards
Signed-off-by: Joffrey F <joffrey@docker.com>
This commit is contained in:
parent
d582ae0009
commit
8d3c7d4bce
|
@ -10,6 +10,7 @@ import six
|
||||||
from . import errors
|
from . import errors
|
||||||
from . import verbose_proxy
|
from . import verbose_proxy
|
||||||
from .. import config
|
from .. import config
|
||||||
|
from .. import parallel
|
||||||
from ..config.environment import Environment
|
from ..config.environment import Environment
|
||||||
from ..const import API_VERSIONS
|
from ..const import API_VERSIONS
|
||||||
from ..project import Project
|
from ..project import Project
|
||||||
|
@ -23,6 +24,8 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
def project_from_options(project_dir, options):
|
def project_from_options(project_dir, options):
|
||||||
environment = Environment.from_env_file(project_dir)
|
environment = Environment.from_env_file(project_dir)
|
||||||
|
set_parallel_limit(environment)
|
||||||
|
|
||||||
host = options.get('--host')
|
host = options.get('--host')
|
||||||
if host is not None:
|
if host is not None:
|
||||||
host = host.lstrip('=')
|
host = host.lstrip('=')
|
||||||
|
@ -38,6 +41,22 @@ def project_from_options(project_dir, options):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_parallel_limit(environment):
|
||||||
|
parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
|
||||||
|
if parallel_limit:
|
||||||
|
try:
|
||||||
|
parallel_limit = int(parallel_limit)
|
||||||
|
except ValueError:
|
||||||
|
raise errors.UserError(
|
||||||
|
'COMPOSE_PARALLEL_LIMIT must be an integer (found: "{}")'.format(
|
||||||
|
environment.get('COMPOSE_PARALLEL_LIMIT')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if parallel_limit <= 1:
|
||||||
|
raise errors.UserError('COMPOSE_PARALLEL_LIMIT can not be less than 2')
|
||||||
|
parallel.GlobalLimit.set_global_limit(parallel_limit)
|
||||||
|
|
||||||
|
|
||||||
def get_config_from_options(base_dir, options):
|
def get_config_from_options(base_dir, options):
|
||||||
environment = Environment.from_env_file(base_dir)
|
environment = Environment.from_env_file(base_dir)
|
||||||
config_path = get_config_path_from_options(
|
config_path = get_config_path_from_options(
|
||||||
|
@ -99,13 +118,8 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False,
|
||||||
host=host, environment=environment
|
host=host, environment=environment
|
||||||
)
|
)
|
||||||
|
|
||||||
global_parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
|
|
||||||
if global_parallel_limit:
|
|
||||||
global_parallel_limit = int(global_parallel_limit)
|
|
||||||
|
|
||||||
with errors.handle_connection_errors(client):
|
with errors.handle_connection_errors(client):
|
||||||
return Project.from_config(project_name, config_data, client,
|
return Project.from_config(project_name, config_data, client)
|
||||||
global_parallel_limit=global_parallel_limit)
|
|
||||||
|
|
||||||
|
|
||||||
def get_project_name(working_dir, project_name=None, environment=None):
|
def get_project_name(working_dir, project_name=None, environment=None):
|
||||||
|
|
|
@ -35,8 +35,9 @@ class GlobalLimit(object):
|
||||||
global_limiter = Semaphore(PARALLEL_LIMIT)
|
global_limiter = Semaphore(PARALLEL_LIMIT)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_global_limit(cls, value=None):
|
def set_global_limit(cls, value):
|
||||||
if value is not None:
|
if value is None:
|
||||||
|
value = PARALLEL_LIMIT
|
||||||
cls.global_limiter = Semaphore(value)
|
cls.global_limiter = Semaphore(value)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -61,15 +61,13 @@ class Project(object):
|
||||||
"""
|
"""
|
||||||
A collection of services.
|
A collection of services.
|
||||||
"""
|
"""
|
||||||
def __init__(self, name, services, client, networks=None, volumes=None, config_version=None,
|
def __init__(self, name, services, client, networks=None, volumes=None, config_version=None):
|
||||||
parallel_limit=None):
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.services = services
|
self.services = services
|
||||||
self.client = client
|
self.client = client
|
||||||
self.volumes = volumes or ProjectVolumes({})
|
self.volumes = volumes or ProjectVolumes({})
|
||||||
self.networks = networks or ProjectNetworks({}, False)
|
self.networks = networks or ProjectNetworks({}, False)
|
||||||
self.config_version = config_version
|
self.config_version = config_version
|
||||||
parallel.GlobalLimit.set_global_limit(value=parallel_limit)
|
|
||||||
|
|
||||||
def labels(self, one_off=OneOffFilter.exclude):
|
def labels(self, one_off=OneOffFilter.exclude):
|
||||||
labels = ['{0}={1}'.format(LABEL_PROJECT, self.name)]
|
labels = ['{0}={1}'.format(LABEL_PROJECT, self.name)]
|
||||||
|
@ -78,7 +76,7 @@ class Project(object):
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, name, config_data, client, global_parallel_limit=None):
|
def from_config(cls, name, config_data, client):
|
||||||
"""
|
"""
|
||||||
Construct a Project from a config.Config object.
|
Construct a Project from a config.Config object.
|
||||||
"""
|
"""
|
||||||
|
@ -89,8 +87,7 @@ class Project(object):
|
||||||
networks,
|
networks,
|
||||||
use_networking)
|
use_networking)
|
||||||
volumes = ProjectVolumes.from_config(name, config_data, client)
|
volumes = ProjectVolumes.from_config(name, config_data, client)
|
||||||
project = cls(name, [], client, project_networks, volumes, config_data.version,
|
project = cls(name, [], client, project_networks, volumes, config_data.version)
|
||||||
parallel_limit=global_parallel_limit)
|
|
||||||
|
|
||||||
for service_dict in config_data.services:
|
for service_dict in config_data.services:
|
||||||
service_dict = dict(service_dict)
|
service_dict = dict(service_dict)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import unittest
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
@ -32,7 +33,9 @@ def get_deps(obj):
|
||||||
return [(dep, None) for dep in deps[obj]]
|
return [(dep, None) for dep in deps[obj]]
|
||||||
|
|
||||||
|
|
||||||
def test_parallel_execute():
|
class ParallelTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_parallel_execute(self):
|
||||||
results, errors = parallel_execute(
|
results, errors = parallel_execute(
|
||||||
objects=[1, 2, 3, 4, 5],
|
objects=[1, 2, 3, 4, 5],
|
||||||
func=lambda x: x * 2,
|
func=lambda x: x * 2,
|
||||||
|
@ -43,8 +46,7 @@ def test_parallel_execute():
|
||||||
assert sorted(results) == [2, 4, 6, 8, 10]
|
assert sorted(results) == [2, 4, 6, 8, 10]
|
||||||
assert errors == {}
|
assert errors == {}
|
||||||
|
|
||||||
|
def test_parallel_execute_with_limit(self):
|
||||||
def test_parallel_execute_with_limit():
|
|
||||||
limit = 1
|
limit = 1
|
||||||
tasks = 20
|
tasks = 20
|
||||||
lock = Lock()
|
lock = Lock()
|
||||||
|
@ -67,9 +69,9 @@ def test_parallel_execute_with_limit():
|
||||||
assert results == tasks * [None]
|
assert results == tasks * [None]
|
||||||
assert errors == {}
|
assert errors == {}
|
||||||
|
|
||||||
|
def test_parallel_execute_with_global_limit(self):
|
||||||
def test_parallel_execute_with_global_limit():
|
|
||||||
GlobalLimit.set_global_limit(1)
|
GlobalLimit.set_global_limit(1)
|
||||||
|
self.addCleanup(GlobalLimit.set_global_limit, None)
|
||||||
tasks = 20
|
tasks = 20
|
||||||
lock = Lock()
|
lock = Lock()
|
||||||
|
|
||||||
|
@ -90,8 +92,7 @@ def test_parallel_execute_with_global_limit():
|
||||||
assert results == tasks * [None]
|
assert results == tasks * [None]
|
||||||
assert errors == {}
|
assert errors == {}
|
||||||
|
|
||||||
|
def test_parallel_execute_with_deps(self):
|
||||||
def test_parallel_execute_with_deps():
|
|
||||||
log = []
|
log = []
|
||||||
|
|
||||||
def process(x):
|
def process(x):
|
||||||
|
@ -111,8 +112,7 @@ def test_parallel_execute_with_deps():
|
||||||
assert log.index(db) < log.index(web)
|
assert log.index(db) < log.index(web)
|
||||||
assert log.index(cache) < log.index(web)
|
assert log.index(cache) < log.index(web)
|
||||||
|
|
||||||
|
def test_parallel_execute_with_upstream_errors(self):
|
||||||
def test_parallel_execute_with_upstream_errors():
|
|
||||||
log = []
|
log = []
|
||||||
|
|
||||||
def process(x):
|
def process(x):
|
||||||
|
|
Loading…
Reference in New Issue