Merge pull request #3108 from dnephin/inplace_var_defaults

Support inline default values for interpolation
This commit is contained in:
Joffrey F 2016-10-05 14:22:02 -07:00 committed by GitHub
commit eac01a7cd5
4 changed files with 130 additions and 76 deletions

View File

@ -413,31 +413,35 @@ def load_services(config_details, config_file):
return build_services(service_config)
def interpolate_config_section(filename, config, section, environment):
validate_config_section(filename, config, section)
return interpolate_environment_variables(config, section, environment)
def interpolate_config_section(config_file, config, section, environment):
validate_config_section(config_file.filename, config, section)
return interpolate_environment_variables(
config_file.version,
config,
section,
environment)
def process_config_file(config_file, environment, service_name=None):
services = interpolate_config_section(
config_file.filename,
config_file,
config_file.get_service_dicts(),
'service',
environment,)
environment)
if config_file.version in (V2_0, V2_1):
processed_config = dict(config_file.config)
processed_config['services'] = services
processed_config['volumes'] = interpolate_config_section(
config_file.filename,
config_file,
config_file.get_volumes(),
'volume',
environment,)
environment)
processed_config['networks'] = interpolate_config_section(
config_file.filename,
config_file,
config_file.get_networks(),
'network',
environment,)
environment)
if config_file.version == V1:
processed_config = services

View File

@ -7,14 +7,35 @@ from string import Template
import six
from .errors import ConfigurationError
from compose.const import COMPOSEFILE_V1 as V1
from compose.const import COMPOSEFILE_V2_0 as V2_0
log = logging.getLogger(__name__)
def interpolate_environment_variables(config, section, environment):
class Interpolator(object):
def __init__(self, templater, mapping):
self.templater = templater
self.mapping = mapping
def interpolate(self, string):
try:
return self.templater(string).substitute(self.mapping)
except ValueError:
raise InvalidInterpolation(string)
def interpolate_environment_variables(version, config, section, environment):
if version in (V2_0, V1):
interpolator = Interpolator(Template, environment)
else:
interpolator = Interpolator(TemplateWithDefaults, environment)
def process_item(name, config_dict):
return dict(
(key, interpolate_value(name, key, val, section, environment))
(key, interpolate_value(name, key, val, section, interpolator))
for key, val in (config_dict or {}).items()
)
@ -24,9 +45,9 @@ def interpolate_environment_variables(config, section, environment):
)
def interpolate_value(name, config_key, value, section, mapping):
def interpolate_value(name, config_key, value, section, interpolator):
try:
return recursive_interpolate(value, mapping)
return recursive_interpolate(value, interpolator)
except InvalidInterpolation as e:
raise ConfigurationError(
'Invalid interpolation format for "{config_key}" option '
@ -37,25 +58,44 @@ def interpolate_value(name, config_key, value, section, mapping):
string=e.string))
def recursive_interpolate(obj, mapping):
def recursive_interpolate(obj, interpolator):
if isinstance(obj, six.string_types):
return interpolate(obj, mapping)
elif isinstance(obj, dict):
return interpolator.interpolate(obj)
if isinstance(obj, dict):
return dict(
(key, recursive_interpolate(val, mapping))
(key, recursive_interpolate(val, interpolator))
for (key, val) in obj.items()
)
elif isinstance(obj, list):
return [recursive_interpolate(val, mapping) for val in obj]
else:
return obj
if isinstance(obj, list):
return [recursive_interpolate(val, interpolator) for val in obj]
return obj
def interpolate(string, mapping):
try:
return Template(string).substitute(mapping)
except ValueError:
raise InvalidInterpolation(string)
class TemplateWithDefaults(Template):
idpattern = r'[_a-z][_a-z0-9]*(?::?-[_a-z0-9]+)?'
# Modified from python2.7/string.py
def substitute(self, mapping):
# Helper function for .sub()
def convert(mo):
# Check the most common path first.
named = mo.group('named') or mo.group('braced')
if named is not None:
if ':-' in named:
var, _, default = named.partition(':-')
return mapping.get(var) or default
if '-' in named:
var, _, default = named.partition('-')
return mapping.get(var, default)
val = mapping[named]
return '%s' % (val,)
if mo.group('escaped') is not None:
return self.delimiter
if mo.group('invalid') is not None:
self._invalid(mo)
raise ValueError('Unrecognized named group in pattern',
self.pattern)
return self.pattern.sub(convert, self.template)
class InvalidInterpolation(Exception):

View File

@ -1,21 +1,28 @@
from __future__ import absolute_import
from __future__ import unicode_literals
import os
import mock
import pytest
from compose.config.environment import Environment
from compose.config.interpolation import interpolate_environment_variables
from compose.config.interpolation import Interpolator
from compose.config.interpolation import InvalidInterpolation
from compose.config.interpolation import TemplateWithDefaults
@pytest.yield_fixture
@pytest.fixture
def mock_env():
with mock.patch.dict(os.environ):
os.environ['USER'] = 'jenny'
os.environ['FOO'] = 'bar'
yield
return Environment({'USER': 'jenny', 'FOO': 'bar'})
@pytest.fixture
def variable_mapping():
return Environment({'FOO': 'first', 'BAR': ''})
@pytest.fixture
def defaults_interpolator(variable_mapping):
return Interpolator(TemplateWithDefaults, variable_mapping).interpolate
def test_interpolate_environment_variables_in_services(mock_env):
@ -43,9 +50,8 @@ def test_interpolate_environment_variables_in_services(mock_env):
}
}
}
assert interpolate_environment_variables(
services, 'service', Environment.from_env_file(None)
) == expected
value = interpolate_environment_variables("2.0", services, 'service', mock_env)
assert value == expected
def test_interpolate_environment_variables_in_volumes(mock_env):
@ -69,6 +75,46 @@ def test_interpolate_environment_variables_in_volumes(mock_env):
},
'other': {},
}
assert interpolate_environment_variables(
volumes, 'volume', Environment.from_env_file(None)
) == expected
value = interpolate_environment_variables("2.0", volumes, 'volume', mock_env)
assert value == expected
def test_escaped_interpolation(defaults_interpolator):
assert defaults_interpolator('$${foo}') == '${foo}'
def test_invalid_interpolation(defaults_interpolator):
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('$}')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${}')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${ }')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${ foo}')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${foo }')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${foo!}')
def test_interpolate_missing_no_default(defaults_interpolator):
assert defaults_interpolator("This ${missing} var") == "This var"
assert defaults_interpolator("This ${BAR} var") == "This var"
def test_interpolate_with_value(defaults_interpolator):
assert defaults_interpolator("This $FOO var") == "This first var"
assert defaults_interpolator("This ${FOO} var") == "This first var"
def test_interpolate_missing_with_default(defaults_interpolator):
assert defaults_interpolator("ok ${missing:-def}") == "ok def"
assert defaults_interpolator("ok ${missing-def}") == "ok def"
def test_interpolate_with_empty_and_default_value(defaults_interpolator):
assert defaults_interpolator("ok ${BAR:-def}") == "ok def"
assert defaults_interpolator("ok ${BAR-def}") == "ok "

View File

@ -1,36 +0,0 @@
from __future__ import absolute_import
from __future__ import unicode_literals
import unittest
from compose.config.environment import Environment as bddict
from compose.config.interpolation import interpolate
from compose.config.interpolation import InvalidInterpolation
class InterpolationTest(unittest.TestCase):
def test_valid_interpolations(self):
self.assertEqual(interpolate('$foo', bddict(foo='hi')), 'hi')
self.assertEqual(interpolate('${foo}', bddict(foo='hi')), 'hi')
self.assertEqual(interpolate('${subject} love you', bddict(subject='i')), 'i love you')
self.assertEqual(interpolate('i ${verb} you', bddict(verb='love')), 'i love you')
self.assertEqual(interpolate('i love ${object}', bddict(object='you')), 'i love you')
def test_empty_value(self):
self.assertEqual(interpolate('${foo}', bddict(foo='')), '')
def test_unset_value(self):
self.assertEqual(interpolate('${foo}', bddict()), '')
def test_escaped_interpolation(self):
self.assertEqual(interpolate('$${foo}', bddict(foo='hi')), '${foo}')
def test_invalid_strings(self):
self.assertRaises(InvalidInterpolation, lambda: interpolate('${', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('$}', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('${}', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('${ }', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('${ foo}', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo }', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo!}', bddict()))