Clean up limit setting code and add reasonable input guards

Signed-off-by: Joffrey F <joffrey@docker.com>
This commit is contained in:
Joffrey F 2018-01-04 12:16:46 -08:00
parent d582ae0009
commit 8d3c7d4bce
4 changed files with 114 additions and 102 deletions

View File

@ -10,6 +10,7 @@ import six
from . import errors from . import errors
from . import verbose_proxy from . import verbose_proxy
from .. import config from .. import config
from .. import parallel
from ..config.environment import Environment from ..config.environment import Environment
from ..const import API_VERSIONS from ..const import API_VERSIONS
from ..project import Project from ..project import Project
@ -23,6 +24,8 @@ log = logging.getLogger(__name__)
def project_from_options(project_dir, options): def project_from_options(project_dir, options):
environment = Environment.from_env_file(project_dir) environment = Environment.from_env_file(project_dir)
set_parallel_limit(environment)
host = options.get('--host') host = options.get('--host')
if host is not None: if host is not None:
host = host.lstrip('=') host = host.lstrip('=')
@ -38,6 +41,22 @@ def project_from_options(project_dir, options):
) )
def set_parallel_limit(environment):
    """
    Apply the COMPOSE_PARALLEL_LIMIT setting from ``environment``, if any.

    Nothing happens when the variable is missing or empty. Otherwise the
    value must parse as an integer that is at least 2, and it is then
    handed to ``parallel.GlobalLimit.set_global_limit``.

    Raises ``errors.UserError`` when the value is not an integer or is
    lower than 2.
    """
    raw_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
    if not raw_limit:
        # Unset or empty: keep whatever global limit is already in place.
        return

    try:
        limit = int(raw_limit)
    except ValueError:
        # Report the value exactly as it appears in the environment.
        found = environment.get('COMPOSE_PARALLEL_LIMIT')
        message = 'COMPOSE_PARALLEL_LIMIT must be an integer (found: "{}")'.format(found)
        raise errors.UserError(message)

    if limit <= 1:
        raise errors.UserError('COMPOSE_PARALLEL_LIMIT can not be less than 2')

    parallel.GlobalLimit.set_global_limit(limit)
def get_config_from_options(base_dir, options): def get_config_from_options(base_dir, options):
environment = Environment.from_env_file(base_dir) environment = Environment.from_env_file(base_dir)
config_path = get_config_path_from_options( config_path = get_config_path_from_options(
@ -99,13 +118,8 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False,
host=host, environment=environment host=host, environment=environment
) )
global_parallel_limit = environment.get('COMPOSE_PARALLEL_LIMIT')
if global_parallel_limit:
global_parallel_limit = int(global_parallel_limit)
with errors.handle_connection_errors(client): with errors.handle_connection_errors(client):
return Project.from_config(project_name, config_data, client, return Project.from_config(project_name, config_data, client)
global_parallel_limit=global_parallel_limit)
def get_project_name(working_dir, project_name=None, environment=None): def get_project_name(working_dir, project_name=None, environment=None):

View File

@ -35,9 +35,10 @@ class GlobalLimit(object):
global_limiter = Semaphore(PARALLEL_LIMIT) global_limiter = Semaphore(PARALLEL_LIMIT)
@classmethod @classmethod
def set_global_limit(cls, value=None): def set_global_limit(cls, value):
if value is not None: if value is None:
cls.global_limiter = Semaphore(value) value = PARALLEL_LIMIT
cls.global_limiter = Semaphore(value)
def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None): def parallel_execute(objects, func, get_name, msg, get_deps=None, limit=None, parent_objects=None):

View File

@ -61,15 +61,13 @@ class Project(object):
""" """
A collection of services. A collection of services.
""" """
def __init__(self, name, services, client, networks=None, volumes=None, config_version=None, def __init__(self, name, services, client, networks=None, volumes=None, config_version=None):
parallel_limit=None):
self.name = name self.name = name
self.services = services self.services = services
self.client = client self.client = client
self.volumes = volumes or ProjectVolumes({}) self.volumes = volumes or ProjectVolumes({})
self.networks = networks or ProjectNetworks({}, False) self.networks = networks or ProjectNetworks({}, False)
self.config_version = config_version self.config_version = config_version
parallel.GlobalLimit.set_global_limit(value=parallel_limit)
def labels(self, one_off=OneOffFilter.exclude): def labels(self, one_off=OneOffFilter.exclude):
labels = ['{0}={1}'.format(LABEL_PROJECT, self.name)] labels = ['{0}={1}'.format(LABEL_PROJECT, self.name)]
@ -78,7 +76,7 @@ class Project(object):
return labels return labels
@classmethod @classmethod
def from_config(cls, name, config_data, client, global_parallel_limit=None): def from_config(cls, name, config_data, client):
""" """
Construct a Project from a config.Config object. Construct a Project from a config.Config object.
""" """
@ -89,8 +87,7 @@ class Project(object):
networks, networks,
use_networking) use_networking)
volumes = ProjectVolumes.from_config(name, config_data, client) volumes = ProjectVolumes.from_config(name, config_data, client)
project = cls(name, [], client, project_networks, volumes, config_data.version, project = cls(name, [], client, project_networks, volumes, config_data.version)
parallel_limit=global_parallel_limit)
for service_dict in config_data.services: for service_dict in config_data.services:
service_dict = dict(service_dict) service_dict = dict(service_dict)

View File

@ -1,6 +1,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
import unittest
from threading import Lock from threading import Lock
import six import six
@ -32,114 +33,113 @@ def get_deps(obj):
return [(dep, None) for dep in deps[obj]] return [(dep, None) for dep in deps[obj]]
def test_parallel_execute(): class ParallelTest(unittest.TestCase):
results, errors = parallel_execute(
objects=[1, 2, 3, 4, 5],
func=lambda x: x * 2,
get_name=six.text_type,
msg="Doubling",
)
assert sorted(results) == [2, 4, 6, 8, 10] def test_parallel_execute(self):
assert errors == {} results, errors = parallel_execute(
objects=[1, 2, 3, 4, 5],
func=lambda x: x * 2,
get_name=six.text_type,
msg="Doubling",
)
assert sorted(results) == [2, 4, 6, 8, 10]
assert errors == {}
def test_parallel_execute_with_limit(): def test_parallel_execute_with_limit(self):
limit = 1 limit = 1
tasks = 20 tasks = 20
lock = Lock() lock = Lock()
def f(obj): def f(obj):
locked = lock.acquire(False) locked = lock.acquire(False)
# we should always get the lock because we're the only thread running # we should always get the lock because we're the only thread running
assert locked assert locked
lock.release() lock.release()
return None return None
results, errors = parallel_execute( results, errors = parallel_execute(
objects=list(range(tasks)), objects=list(range(tasks)),
func=f, func=f,
get_name=six.text_type, get_name=six.text_type,
msg="Testing", msg="Testing",
limit=limit, limit=limit,
) )
assert results == tasks * [None] assert results == tasks * [None]
assert errors == {} assert errors == {}
def test_parallel_execute_with_global_limit(self):
GlobalLimit.set_global_limit(1)
self.addCleanup(GlobalLimit.set_global_limit, None)
tasks = 20
lock = Lock()
def test_parallel_execute_with_global_limit(): def f(obj):
GlobalLimit.set_global_limit(1) locked = lock.acquire(False)
tasks = 20 # we should always get the lock because we're the only thread running
lock = Lock() assert locked
lock.release()
return None
def f(obj): results, errors = parallel_execute(
locked = lock.acquire(False) objects=list(range(tasks)),
# we should always get the lock because we're the only thread running func=f,
assert locked get_name=six.text_type,
lock.release() msg="Testing",
return None )
results, errors = parallel_execute( assert results == tasks * [None]
objects=list(range(tasks)), assert errors == {}
func=f,
get_name=six.text_type,
msg="Testing",
)
assert results == tasks * [None] def test_parallel_execute_with_deps(self):
assert errors == {} log = []
def process(x):
log.append(x)
def test_parallel_execute_with_deps(): parallel_execute(
log = [] objects=objects,
func=process,
get_name=lambda obj: obj,
msg="Processing",
get_deps=get_deps,
)
def process(x): assert sorted(log) == sorted(objects)
log.append(x)
parallel_execute( assert log.index(data_volume) < log.index(db)
objects=objects, assert log.index(db) < log.index(web)
func=process, assert log.index(cache) < log.index(web)
get_name=lambda obj: obj,
msg="Processing",
get_deps=get_deps,
)
assert sorted(log) == sorted(objects) def test_parallel_execute_with_upstream_errors(self):
log = []
assert log.index(data_volume) < log.index(db) def process(x):
assert log.index(db) < log.index(web) if x is data_volume:
assert log.index(cache) < log.index(web) raise APIError(None, None, "Something went wrong")
log.append(x)
parallel_execute(
objects=objects,
func=process,
get_name=lambda obj: obj,
msg="Processing",
get_deps=get_deps,
)
def test_parallel_execute_with_upstream_errors(): assert log == [cache]
log = []
def process(x): events = [
if x is data_volume: (obj, result, type(exception))
raise APIError(None, None, "Something went wrong") for obj, result, exception
log.append(x) in parallel_execute_iter(objects, process, get_deps, None)
]
parallel_execute( assert (cache, None, type(None)) in events
objects=objects, assert (data_volume, None, APIError) in events
func=process, assert (db, None, UpstreamError) in events
get_name=lambda obj: obj, assert (web, None, UpstreamError) in events
msg="Processing",
get_deps=get_deps,
)
assert log == [cache]
events = [
(obj, result, type(exception))
for obj, result, exception
in parallel_execute_iter(objects, process, get_deps, None)
]
assert (cache, None, type(None)) in events
assert (data_volume, None, APIError) in events
assert (db, None, UpstreamError) in events
assert (web, None, UpstreamError) in events
def test_parallel_execute_alignment(capsys): def test_parallel_execute_alignment(capsys):