Merge pull request #3108 from dnephin/inplace_var_defaults

Support inline default values for interpolation
This commit is contained in:
Joffrey F 2016-10-05 14:22:02 -07:00 committed by GitHub
commit eac01a7cd5
4 changed files with 130 additions and 76 deletions

View File

@ -413,31 +413,35 @@ def load_services(config_details, config_file):
return build_services(service_config) return build_services(service_config)
def interpolate_config_section(filename, config, section, environment): def interpolate_config_section(config_file, config, section, environment):
validate_config_section(filename, config, section) validate_config_section(config_file.filename, config, section)
return interpolate_environment_variables(config, section, environment) return interpolate_environment_variables(
config_file.version,
config,
section,
environment)
def process_config_file(config_file, environment, service_name=None): def process_config_file(config_file, environment, service_name=None):
services = interpolate_config_section( services = interpolate_config_section(
config_file.filename, config_file,
config_file.get_service_dicts(), config_file.get_service_dicts(),
'service', 'service',
environment,) environment)
if config_file.version in (V2_0, V2_1): if config_file.version in (V2_0, V2_1):
processed_config = dict(config_file.config) processed_config = dict(config_file.config)
processed_config['services'] = services processed_config['services'] = services
processed_config['volumes'] = interpolate_config_section( processed_config['volumes'] = interpolate_config_section(
config_file.filename, config_file,
config_file.get_volumes(), config_file.get_volumes(),
'volume', 'volume',
environment,) environment)
processed_config['networks'] = interpolate_config_section( processed_config['networks'] = interpolate_config_section(
config_file.filename, config_file,
config_file.get_networks(), config_file.get_networks(),
'network', 'network',
environment,) environment)
if config_file.version == V1: if config_file.version == V1:
processed_config = services processed_config = services

View File

@ -7,14 +7,35 @@ from string import Template
import six import six
from .errors import ConfigurationError from .errors import ConfigurationError
from compose.const import COMPOSEFILE_V1 as V1
from compose.const import COMPOSEFILE_V2_0 as V2_0
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def interpolate_environment_variables(config, section, environment): class Interpolator(object):
def __init__(self, templater, mapping):
self.templater = templater
self.mapping = mapping
def interpolate(self, string):
try:
return self.templater(string).substitute(self.mapping)
except ValueError:
raise InvalidInterpolation(string)
def interpolate_environment_variables(version, config, section, environment):
if version in (V2_0, V1):
interpolator = Interpolator(Template, environment)
else:
interpolator = Interpolator(TemplateWithDefaults, environment)
def process_item(name, config_dict): def process_item(name, config_dict):
return dict( return dict(
(key, interpolate_value(name, key, val, section, environment)) (key, interpolate_value(name, key, val, section, interpolator))
for key, val in (config_dict or {}).items() for key, val in (config_dict or {}).items()
) )
@ -24,9 +45,9 @@ def interpolate_environment_variables(config, section, environment):
) )
def interpolate_value(name, config_key, value, section, mapping): def interpolate_value(name, config_key, value, section, interpolator):
try: try:
return recursive_interpolate(value, mapping) return recursive_interpolate(value, interpolator)
except InvalidInterpolation as e: except InvalidInterpolation as e:
raise ConfigurationError( raise ConfigurationError(
'Invalid interpolation format for "{config_key}" option ' 'Invalid interpolation format for "{config_key}" option '
@ -37,25 +58,44 @@ def interpolate_value(name, config_key, value, section, mapping):
string=e.string)) string=e.string))
def recursive_interpolate(obj, mapping): def recursive_interpolate(obj, interpolator):
if isinstance(obj, six.string_types): if isinstance(obj, six.string_types):
return interpolate(obj, mapping) return interpolator.interpolate(obj)
elif isinstance(obj, dict): if isinstance(obj, dict):
return dict( return dict(
(key, recursive_interpolate(val, mapping)) (key, recursive_interpolate(val, interpolator))
for (key, val) in obj.items() for (key, val) in obj.items()
) )
elif isinstance(obj, list): if isinstance(obj, list):
return [recursive_interpolate(val, mapping) for val in obj] return [recursive_interpolate(val, interpolator) for val in obj]
else: return obj
return obj
def interpolate(string, mapping): class TemplateWithDefaults(Template):
try: idpattern = r'[_a-z][_a-z0-9]*(?::?-[_a-z0-9]+)?'
return Template(string).substitute(mapping)
except ValueError: # Modified from python2.7/string.py
raise InvalidInterpolation(string) def substitute(self, mapping):
# Helper function for .sub()
def convert(mo):
# Check the most common path first.
named = mo.group('named') or mo.group('braced')
if named is not None:
if ':-' in named:
var, _, default = named.partition(':-')
return mapping.get(var) or default
if '-' in named:
var, _, default = named.partition('-')
return mapping.get(var, default)
val = mapping[named]
return '%s' % (val,)
if mo.group('escaped') is not None:
return self.delimiter
if mo.group('invalid') is not None:
self._invalid(mo)
raise ValueError('Unrecognized named group in pattern',
self.pattern)
return self.pattern.sub(convert, self.template)
class InvalidInterpolation(Exception): class InvalidInterpolation(Exception):

View File

@ -1,21 +1,28 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
import os
import mock
import pytest import pytest
from compose.config.environment import Environment from compose.config.environment import Environment
from compose.config.interpolation import interpolate_environment_variables from compose.config.interpolation import interpolate_environment_variables
from compose.config.interpolation import Interpolator
from compose.config.interpolation import InvalidInterpolation
from compose.config.interpolation import TemplateWithDefaults
@pytest.yield_fixture @pytest.fixture
def mock_env(): def mock_env():
with mock.patch.dict(os.environ): return Environment({'USER': 'jenny', 'FOO': 'bar'})
os.environ['USER'] = 'jenny'
os.environ['FOO'] = 'bar'
yield @pytest.fixture
def variable_mapping():
return Environment({'FOO': 'first', 'BAR': ''})
@pytest.fixture
def defaults_interpolator(variable_mapping):
return Interpolator(TemplateWithDefaults, variable_mapping).interpolate
def test_interpolate_environment_variables_in_services(mock_env): def test_interpolate_environment_variables_in_services(mock_env):
@ -43,9 +50,8 @@ def test_interpolate_environment_variables_in_services(mock_env):
} }
} }
} }
assert interpolate_environment_variables( value = interpolate_environment_variables("2.0", services, 'service', mock_env)
services, 'service', Environment.from_env_file(None) assert value == expected
) == expected
def test_interpolate_environment_variables_in_volumes(mock_env): def test_interpolate_environment_variables_in_volumes(mock_env):
@ -69,6 +75,46 @@ def test_interpolate_environment_variables_in_volumes(mock_env):
}, },
'other': {}, 'other': {},
} }
assert interpolate_environment_variables( value = interpolate_environment_variables("2.0", volumes, 'volume', mock_env)
volumes, 'volume', Environment.from_env_file(None) assert value == expected
) == expected
def test_escaped_interpolation(defaults_interpolator):
assert defaults_interpolator('$${foo}') == '${foo}'
def test_invalid_interpolation(defaults_interpolator):
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('$}')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${}')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${ }')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${ foo}')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${foo }')
with pytest.raises(InvalidInterpolation):
defaults_interpolator('${foo!}')
def test_interpolate_missing_no_default(defaults_interpolator):
assert defaults_interpolator("This ${missing} var") == "This var"
assert defaults_interpolator("This ${BAR} var") == "This var"
def test_interpolate_with_value(defaults_interpolator):
assert defaults_interpolator("This $FOO var") == "This first var"
assert defaults_interpolator("This ${FOO} var") == "This first var"
def test_interpolate_missing_with_default(defaults_interpolator):
assert defaults_interpolator("ok ${missing:-def}") == "ok def"
assert defaults_interpolator("ok ${missing-def}") == "ok def"
def test_interpolate_with_empty_and_default_value(defaults_interpolator):
assert defaults_interpolator("ok ${BAR:-def}") == "ok def"
assert defaults_interpolator("ok ${BAR-def}") == "ok "

View File

@ -1,36 +0,0 @@
from __future__ import absolute_import
from __future__ import unicode_literals
import unittest
from compose.config.environment import Environment as bddict
from compose.config.interpolation import interpolate
from compose.config.interpolation import InvalidInterpolation
class InterpolationTest(unittest.TestCase):
def test_valid_interpolations(self):
self.assertEqual(interpolate('$foo', bddict(foo='hi')), 'hi')
self.assertEqual(interpolate('${foo}', bddict(foo='hi')), 'hi')
self.assertEqual(interpolate('${subject} love you', bddict(subject='i')), 'i love you')
self.assertEqual(interpolate('i ${verb} you', bddict(verb='love')), 'i love you')
self.assertEqual(interpolate('i love ${object}', bddict(object='you')), 'i love you')
def test_empty_value(self):
self.assertEqual(interpolate('${foo}', bddict(foo='')), '')
def test_unset_value(self):
self.assertEqual(interpolate('${foo}', bddict()), '')
def test_escaped_interpolation(self):
self.assertEqual(interpolate('$${foo}', bddict(foo='hi')), '${foo}')
def test_invalid_strings(self):
self.assertRaises(InvalidInterpolation, lambda: interpolate('${', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('$}', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('${}', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('${ }', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('${ foo}', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo }', bddict()))
self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo!}', bddict()))