From fe08be698d36d42a66839ce284989947220931cd Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Thu, 10 Mar 2016 17:33:01 -0500 Subject: [PATCH] Support inline default values. Signed-off-by: Daniel Nephin --- compose/config/config.py | 22 +++++--- compose/config/interpolation.py | 74 +++++++++++++++++++------ tests/unit/config/interpolation_test.py | 74 ++++++++++++++++++++----- tests/unit/interpolation_test.py | 36 ------------ 4 files changed, 130 insertions(+), 76 deletions(-) delete mode 100644 tests/unit/interpolation_test.py diff --git a/compose/config/config.py b/compose/config/config.py index aea1e0949..4d32b50c4 100644 --- a/compose/config/config.py +++ b/compose/config/config.py @@ -413,31 +413,35 @@ def load_services(config_details, config_file): return build_services(service_config) -def interpolate_config_section(filename, config, section, environment): - validate_config_section(filename, config, section) - return interpolate_environment_variables(config, section, environment) +def interpolate_config_section(config_file, config, section, environment): + validate_config_section(config_file.filename, config, section) + return interpolate_environment_variables( + config_file.version, + config, + section, + environment) def process_config_file(config_file, environment, service_name=None): services = interpolate_config_section( - config_file.filename, + config_file, config_file.get_service_dicts(), 'service', - environment,) + environment) if config_file.version in (V2_0, V2_1): processed_config = dict(config_file.config) processed_config['services'] = services processed_config['volumes'] = interpolate_config_section( - config_file.filename, + config_file, config_file.get_volumes(), 'volume', - environment,) + environment) processed_config['networks'] = interpolate_config_section( - config_file.filename, + config_file, config_file.get_networks(), 'network', - environment,) + environment) if config_file.version == V1: processed_config = services diff --git a/compose/config/interpolation.py b/compose/config/interpolation.py index 63020d91a..cb841437c 100644 --- a/compose/config/interpolation.py +++ b/compose/config/interpolation.py @@ -7,14 +7,35 @@ from string import Template import six from .errors import ConfigurationError +from compose.const import COMPOSEFILE_V1 as V1 +from compose.const import COMPOSEFILE_V2_0 as V2_0 + + log = logging.getLogger(__name__) -def interpolate_environment_variables(config, section, environment): +class Interpolator(object): + + def __init__(self, templater, mapping): + self.templater = templater + self.mapping = mapping + + def interpolate(self, string): + try: + return self.templater(string).substitute(self.mapping) + except ValueError: + raise InvalidInterpolation(string) + + +def interpolate_environment_variables(version, config, section, environment): + if version in (V2_0, V1): + interpolator = Interpolator(Template, environment) + else: + interpolator = Interpolator(TemplateWithDefaults, environment) def process_item(name, config_dict): return dict( - (key, interpolate_value(name, key, val, section, environment)) + (key, interpolate_value(name, key, val, section, interpolator)) for key, val in (config_dict or {}).items() ) @@ -24,9 +45,9 @@ def interpolate_environment_variables(config, section, environment): ) -def interpolate_value(name, config_key, value, section, mapping): +def interpolate_value(name, config_key, value, section, interpolator): try: - return recursive_interpolate(value, mapping) + return recursive_interpolate(value, interpolator) except InvalidInterpolation as e: raise ConfigurationError( 'Invalid interpolation format for "{config_key}" option ' @@ -37,25 +58,44 @@ def interpolate_value(name, config_key, value, section, mapping): string=e.string)) -def recursive_interpolate(obj, mapping): +def recursive_interpolate(obj, interpolator): if isinstance(obj, six.string_types): - return interpolate(obj, mapping) - elif isinstance(obj, dict): + return interpolator.interpolate(obj) + if isinstance(obj, dict): return dict( - (key, recursive_interpolate(val, mapping)) + (key, recursive_interpolate(val, interpolator)) for (key, val) in obj.items() ) - elif isinstance(obj, list): - return [recursive_interpolate(val, mapping) for val in obj] - else: - return obj + if isinstance(obj, list): + return [recursive_interpolate(val, interpolator) for val in obj] + return obj -def interpolate(string, mapping): - try: - return Template(string).substitute(mapping) - except ValueError: - raise InvalidInterpolation(string) +class TemplateWithDefaults(Template): + idpattern = r'[_a-z][_a-z0-9]*(?::?-[_a-z0-9]+)?' + + # Modified from python2.7/string.py + def substitute(self, mapping): + # Helper function for .sub() + def convert(mo): + # Check the most common path first. + named = mo.group('named') or mo.group('braced') + if named is not None: + if ':-' in named: + var, _, default = named.partition(':-') + return mapping.get(var) or default + if '-' in named: + var, _, default = named.partition('-') + return mapping.get(var, default) + val = mapping[named] + return '%s' % (val,) + if mo.group('escaped') is not None: + return self.delimiter + if mo.group('invalid') is not None: + self._invalid(mo) + raise ValueError('Unrecognized named group in pattern', + self.pattern) + return self.pattern.sub(convert, self.template) class InvalidInterpolation(Exception): diff --git a/tests/unit/config/interpolation_test.py b/tests/unit/config/interpolation_test.py index 42b5db6e9..224444950 100644 --- a/tests/unit/config/interpolation_test.py +++ b/tests/unit/config/interpolation_test.py @@ -1,21 +1,28 @@ from __future__ import absolute_import from __future__ import unicode_literals -import os - -import mock import pytest from compose.config.environment import Environment from compose.config.interpolation import interpolate_environment_variables +from compose.config.interpolation import Interpolator +from compose.config.interpolation import InvalidInterpolation +from compose.config.interpolation import TemplateWithDefaults -@pytest.yield_fixture +@pytest.fixture def mock_env(): - with mock.patch.dict(os.environ): - os.environ['USER'] = 'jenny' - os.environ['FOO'] = 'bar' - yield + return Environment({'USER': 'jenny', 'FOO': 'bar'}) + + +@pytest.fixture +def variable_mapping(): + return Environment({'FOO': 'first', 'BAR': ''}) + + +@pytest.fixture +def defaults_interpolator(variable_mapping): + return Interpolator(TemplateWithDefaults, variable_mapping).interpolate def test_interpolate_environment_variables_in_services(mock_env): @@ -43,9 +50,8 @@ def test_interpolate_environment_variables_in_services(mock_env): } } } - assert interpolate_environment_variables( - services, 'service', Environment.from_env_file(None) - ) == expected + value = interpolate_environment_variables("2.0", services, 'service', mock_env) + assert value == expected def test_interpolate_environment_variables_in_volumes(mock_env): @@ -69,6 +75,46 @@ def test_interpolate_environment_variables_in_volumes(mock_env): }, 'other': {}, } - assert interpolate_environment_variables( - volumes, 'volume', Environment.from_env_file(None) - ) == expected + value = interpolate_environment_variables("2.0", volumes, 'volume', mock_env) + assert value == expected + + +def test_escaped_interpolation(defaults_interpolator): + assert defaults_interpolator('$${foo}') == '${foo}' + + +def test_invalid_interpolation(defaults_interpolator): + with pytest.raises(InvalidInterpolation): + defaults_interpolator('${') + with pytest.raises(InvalidInterpolation): + defaults_interpolator('$}') + with pytest.raises(InvalidInterpolation): + defaults_interpolator('${}') + with pytest.raises(InvalidInterpolation): + defaults_interpolator('${ }') + with pytest.raises(InvalidInterpolation): + defaults_interpolator('${ foo}') + with pytest.raises(InvalidInterpolation): + defaults_interpolator('${foo }') + with pytest.raises(InvalidInterpolation): + defaults_interpolator('${foo!}') + + +def test_interpolate_missing_no_default(defaults_interpolator): + assert defaults_interpolator("This ${missing} var") == "This var" + assert defaults_interpolator("This ${BAR} var") == "This var" + + +def test_interpolate_with_value(defaults_interpolator): + assert defaults_interpolator("This $FOO var") == "This first var" + assert defaults_interpolator("This ${FOO} var") == "This first var" + + +def test_interpolate_missing_with_default(defaults_interpolator): + assert defaults_interpolator("ok ${missing:-def}") == "ok def" + assert defaults_interpolator("ok ${missing-def}") == "ok def" + + +def test_interpolate_with_empty_and_default_value(defaults_interpolator): + assert defaults_interpolator("ok ${BAR:-def}") == "ok def" + assert defaults_interpolator("ok ${BAR-def}") == "ok " diff --git a/tests/unit/interpolation_test.py b/tests/unit/interpolation_test.py deleted file mode 100644 index c3050c2ca..000000000 --- a/tests/unit/interpolation_test.py +++ /dev/null @@ -1,36 +0,0 @@ -from __future__ import absolute_import -from __future__ import unicode_literals - -import unittest - -from compose.config.environment import Environment as bddict -from compose.config.interpolation import interpolate -from compose.config.interpolation import InvalidInterpolation - - -class InterpolationTest(unittest.TestCase): - def test_valid_interpolations(self): - self.assertEqual(interpolate('$foo', bddict(foo='hi')), 'hi') - self.assertEqual(interpolate('${foo}', bddict(foo='hi')), 'hi') - - self.assertEqual(interpolate('${subject} love you', bddict(subject='i')), 'i love you') - self.assertEqual(interpolate('i ${verb} you', bddict(verb='love')), 'i love you') - self.assertEqual(interpolate('i love ${object}', bddict(object='you')), 'i love you') - - def test_empty_value(self): - self.assertEqual(interpolate('${foo}', bddict(foo='')), '') - - def test_unset_value(self): - self.assertEqual(interpolate('${foo}', bddict()), '') - - def test_escaped_interpolation(self): - self.assertEqual(interpolate('$${foo}', bddict(foo='hi')), '${foo}') - - def test_invalid_strings(self): - self.assertRaises(InvalidInterpolation, lambda: interpolate('${', bddict())) - self.assertRaises(InvalidInterpolation, lambda: interpolate('$}', bddict())) - self.assertRaises(InvalidInterpolation, lambda: interpolate('${}', bddict())) - self.assertRaises(InvalidInterpolation, lambda: interpolate('${ }', bddict())) - self.assertRaises(InvalidInterpolation, lambda: interpolate('${ foo}', bddict())) - self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo }', bddict())) - self.assertRaises(InvalidInterpolation, lambda: interpolate('${foo!}', bddict()))