mirror of https://github.com/docker/compose.git
Merge pull request #3108 from dnephin/inplace_var_defaults
Support inline default values for interpolation
This commit is contained in:
commit
eac01a7cd5
|
@ -413,31 +413,35 @@ def load_services(config_details, config_file):
|
|||
return build_services(service_config)
|
||||
|
||||
|
||||
def interpolate_config_section(filename, config, section, environment):
|
||||
validate_config_section(filename, config, section)
|
||||
return interpolate_environment_variables(config, section, environment)
|
||||
def interpolate_config_section(config_file, config, section, environment):
|
||||
validate_config_section(config_file.filename, config, section)
|
||||
return interpolate_environment_variables(
|
||||
config_file.version,
|
||||
config,
|
||||
section,
|
||||
environment)
|
||||
|
||||
|
||||
def process_config_file(config_file, environment, service_name=None):
|
||||
services = interpolate_config_section(
|
||||
config_file.filename,
|
||||
config_file,
|
||||
config_file.get_service_dicts(),
|
||||
'service',
|
||||
environment,)
|
||||
environment)
|
||||
|
||||
if config_file.version in (V2_0, V2_1):
|
||||
processed_config = dict(config_file.config)
|
||||
processed_config['services'] = services
|
||||
processed_config['volumes'] = interpolate_config_section(
|
||||
config_file.filename,
|
||||
config_file,
|
||||
config_file.get_volumes(),
|
||||
'volume',
|
||||
environment,)
|
||||
environment)
|
||||
processed_config['networks'] = interpolate_config_section(
|
||||
config_file.filename,
|
||||
config_file,
|
||||
config_file.get_networks(),
|
||||
'network',
|
||||
environment,)
|
||||
environment)
|
||||
|
||||
if config_file.version == V1:
|
||||
processed_config = services
|
||||
|
|
|
@ -7,14 +7,35 @@ from string import Template
|
|||
import six
|
||||
|
||||
from .errors import ConfigurationError
|
||||
from compose.const import COMPOSEFILE_V1 as V1
|
||||
from compose.const import COMPOSEFILE_V2_0 as V2_0
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def interpolate_environment_variables(config, section, environment):
|
||||
class Interpolator(object):
|
||||
|
||||
def __init__(self, templater, mapping):
|
||||
self.templater = templater
|
||||
self.mapping = mapping
|
||||
|
||||
def interpolate(self, string):
|
||||
try:
|
||||
return self.templater(string).substitute(self.mapping)
|
||||
except ValueError:
|
||||
raise InvalidInterpolation(string)
|
||||
|
||||
|
||||
def interpolate_environment_variables(version, config, section, environment):
|
||||
if version in (V2_0, V1):
|
||||
interpolator = Interpolator(Template, environment)
|
||||
else:
|
||||
interpolator = Interpolator(TemplateWithDefaults, environment)
|
||||
|
||||
def process_item(name, config_dict):
|
||||
return dict(
|
||||
(key, interpolate_value(name, key, val, section, environment))
|
||||
(key, interpolate_value(name, key, val, section, interpolator))
|
||||
for key, val in (config_dict or {}).items()
|
||||
)
|
||||
|
||||
|
@ -24,9 +45,9 @@ def interpolate_environment_variables(config, section, environment):
|
|||
)
|
||||
|
||||
|
||||
def interpolate_value(name, config_key, value, section, mapping):
|
||||
def interpolate_value(name, config_key, value, section, interpolator):
|
||||
try:
|
||||
return recursive_interpolate(value, mapping)
|
||||
return recursive_interpolate(value, interpolator)
|
||||
except InvalidInterpolation as e:
|
||||
raise ConfigurationError(
|
||||
'Invalid interpolation format for "{config_key}" option '
|
||||
|
@ -37,25 +58,44 @@ def interpolate_value(name, config_key, value, section, mapping):
|
|||
string=e.string))
|
||||
|
||||
|
||||
def recursive_interpolate(obj, mapping):
|
||||
def recursive_interpolate(obj, interpolator):
|
||||
if isinstance(obj, six.string_types):
|
||||
return interpolate(obj, mapping)
|
||||
elif isinstance(obj, dict):
|
||||
return interpolator.interpolate(obj)
|
||||
if isinstance(obj, dict):
|
||||
return dict(
|
||||
(key, recursive_interpolate(val, mapping))
|
||||
(key, recursive_interpolate(val, interpolator))
|
||||
for (key, val) in obj.items()
|
||||
)
|
||||
elif isinstance(obj, list):
|
||||
return [recursive_interpolate(val, mapping) for val in obj]
|
||||
else:
|
||||
return obj
|
||||
if isinstance(obj, list):
|
||||
return [recursive_interpolate(val, interpolator) for val in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def interpolate(string, mapping):
|
||||
try:
|
||||
return Template(string).substitute(mapping)
|
||||
except ValueError:
|
||||
raise InvalidInterpolation(string)
|
||||
class TemplateWithDefaults(Template):
|
||||
idpattern = r'[_a-z][_a-z0-9]*(?::?-[_a-z0-9]+)?'
|
||||
|
||||
# Modified from python2.7/string.py
|
||||
def substitute(self, mapping):
|
||||
# Helper function for .sub()
|
||||
def convert(mo):
|
||||
# Check the most common path first.
|
||||
named = mo.group('named') or mo.group('braced')
|
||||
if named is not None:
|
||||
if ':-' in named:
|
||||
var, _, default = named.partition(':-')
|
||||
return mapping.get(var) or default
|
||||
if '-' in named:
|
||||
var, _, default = named.partition('-')
|
||||
return mapping.get(var, default)
|
||||
val = mapping[named]
|
||||
return '%s' % (val,)
|
||||
if mo.group('escaped') is not None:
|
||||
return self.delimiter
|
||||
if mo.group('invalid') is not None:
|
||||
self._invalid(mo)
|
||||
raise ValueError('Unrecognized named group in pattern',
|
||||
self.pattern)
|
||||
return self.pattern.sub(convert, self.template)
|
||||
|
||||
|
||||
class InvalidInterpolation(Exception):
|
||||
|
|
|
@ -1,21 +1,28 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
|
||||
import mock
|
||||
import pytest
|
||||
|
||||
from compose.config.environment import Environment
|
||||
from compose.config.interpolation import interpolate_environment_variables
|
||||
from compose.config.interpolation import Interpolator
|
||||
from compose.config.interpolation import InvalidInterpolation
|
||||
from compose.config.interpolation import TemplateWithDefaults
|
||||
|
||||
|
||||
@pytest.yield_fixture
|
||||
@pytest.fixture
|
||||
def mock_env():
|
||||
with mock.patch.dict(os.environ):
|
||||
os.environ['USER'] = 'jenny'
|
||||
os.environ['FOO'] = 'bar'
|
||||
yield
|
||||
return Environment({'USER': 'jenny', 'FOO': 'bar'})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def variable_mapping():
|
||||
return Environment({'FOO': 'first', 'BAR': ''})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def defaults_interpolator(variable_mapping):
|
||||
return Interpolator(TemplateWithDefaults, variable_mapping).interpolate
|
||||
|
||||
|
||||
def test_interpolate_environment_variables_in_services(mock_env):
|
||||
|
@ -43,9 +50,8 @@ def test_interpolate_environment_variables_in_services(mock_env):
|
|||
}
|
||||
}
|
||||
}
|
||||
assert interpolate_environment_variables(
|
||||
services, 'service', Environment.from_env_file(None)
|
||||
) == expected
|
||||
value = interpolate_environment_variables("2.0", services, 'service', mock_env)
|
||||
assert value == expected
|
||||
|
||||
|
||||
def test_interpolate_environment_variables_in_volumes(mock_env):
|
||||
|
@ -69,6 +75,46 @@ def test_interpolate_environment_variables_in_volumes(mock_env):
|
|||
},
|
||||
'other': {},
|
||||
}
|
||||
assert interpolate_environment_variables(
|
||||
volumes, 'volume', Environment.from_env_file(None)
|
||||
) == expected
|
||||
value = interpolate_environment_variables("2.0", volumes, 'volume', mock_env)
|
||||
assert value == expected
|
||||
|
||||
|
||||
def test_escaped_interpolation(defaults_interpolator):
|
||||
assert defaults_interpolator('$${foo}') == '${foo}'
|
||||
|
||||
|
||||
def test_invalid_interpolation(defaults_interpolator):
|
||||
with pytest.raises(InvalidInterpolation):
|
||||
defaults_interpolator('${')
|
||||
with pytest.raises(InvalidInterpolation):
|
||||
defaults_interpolator('$}')
|
||||
with pytest.raises(InvalidInterpolation):
|
||||
defaults_interpolator('${}')
|
||||
with pytest.raises(InvalidInterpolation):
|
||||
defaults_interpolator('${ }')
|
||||
with pytest.raises(InvalidInterpolation):
|
||||
defaults_interpolator('${ foo}')
|
||||
with pytest.raises(InvalidInterpolation):
|
||||
defaults_interpolator('${foo }')
|
||||
with pytest.raises(InvalidInterpolation):
|
||||
defaults_interpolator('${foo!}')
|
||||
|
||||
|
||||
def test_interpolate_missing_no_default(defaults_interpolator):
|
||||
assert defaults_interpolator("This ${missing} var") == "This var"
|
||||
assert defaults_interpolator("This ${BAR} var") == "This var"
|
||||
|
||||
|
||||
def test_interpolate_with_value(defaults_interpolator):
|
||||
assert defaults_interpolator("This $FOO var") == "This first var"
|
||||
assert defaults_interpolator("This ${FOO} var") == "This first var"
|
||||
|
||||
|
||||
def test_interpolate_missing_with_default(defaults_interpolator):
|
||||
assert defaults_interpolator("ok ${missing:-def}") == "ok def"
|
||||
assert defaults_interpolator("ok ${missing-def}") == "ok def"
|
||||
|
||||
|
||||
def test_interpolate_with_empty_and_default_value(defaults_interpolator):
|
||||
assert defaults_interpolator("ok ${BAR:-def}") == "ok def"
|
||||
assert defaults_interpolator("ok ${BAR-def}") == "ok "
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
|
||||
from compose.config.environment import Environment as bddict
|
||||
from compose.config.interpolation import interpolate
|
||||
from compose.config.interpolation import InvalidInterpolation
|
||||
|
||||
|
||||
class InterpolationTest(unittest.TestCase):
|
||||
def test_valid_interpolations(self):
|
||||
self.assertEqual(interpolate('$foo', bddict(foo='hi')), 'hi')
|
||||
self.assertEqual(interpolate('${foo}', bddict(foo='hi')), 'hi')
|
||||
|
||||
self.assertEqual(interpolate('${subject} love you', bddict(subject='i')), 'i love you')
|
||||
self.assertEqual(interpolate('i ${verb} you', bddict(verb='love')), 'i love you')
|
||||
self.assertEqual(interpolate('i love ${object}', bddict(object='you')), 'i love you')
|
||||
|
||||
def test_empty_value(self):
|
||||
self.assertEqual(interpolate('${foo}', bddict(foo='')), '')
|
||||
|
||||
def test_unset_value(self):
|
||||
self.assertEqual(interpolate('${foo}', bddict()), '')
|
||||
|
||||
def test_escaped_interpolation(self):
|
||||
self.assertEqual(interpolate('$${foo}', bddict(foo='hi')), '${foo}')
|
||||
|
||||
def test_invalid_strings(self):
|
||||
self.assertRaises(InvalidInterpolation, lambda: interpolate('${', bddict()))
|
||||
self.assertRaises(InvalidInterpolation, lambda: interpolate('$}', bddict()))
|
||||
self.assertRaises(InvalidInterpolation, lambda: interpolate('${}', bddict()))
|
||||
self.assertRaises(InvalidInterpolation, lambda: interpolate('${ }', bddict()))
|
||||
self.assertRaises(InvalidInterpolation, lambda: interpolate('${ foo}', bddict()))
|
||||
self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo }', bddict()))
|
||||
self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo!}', bddict()))
|
Loading…
Reference in New Issue