From e400c05de09bc1a1cdb287a83ac56a05d77ab044 Mon Sep 17 00:00:00 2001 From: Joffrey F Date: Wed, 3 Jan 2018 18:30:26 -0800 Subject: [PATCH] Support ${VAR:?err} syntax for mandatory variables Signed-off-by: Joffrey F --- compose/config/interpolation.py | 63 +++++++++++++++++++++---- tests/unit/cli/formatter_test.py | 1 - tests/unit/config/interpolation_test.py | 40 +++++++++++++++- 3 files changed, 94 insertions(+), 10 deletions(-) diff --git a/compose/config/interpolation.py b/compose/config/interpolation.py index 68d3be682..0f2630383 100644 --- a/compose/config/interpolation.py +++ b/compose/config/interpolation.py @@ -60,6 +60,15 @@ def interpolate_value(name, config_key, value, section, interpolator): name=name, section=section, string=e.string)) + except UnsetRequiredSubstitution as e: + raise ConfigurationError( + 'Missing mandatory value for "{config_key}" option in {section} "{name}": {err}'.format( + config_key=config_key, + name=name, + section=section, + err=e.err + ) + ) def recursive_interpolate(obj, interpolator, config_path): @@ -79,21 +88,54 @@ def recursive_interpolate(obj, interpolator, config_path): class TemplateWithDefaults(Template): - idpattern = r'[_a-z][_a-z0-9]*(?::?-[^}]*)?' + pattern = r""" + %(delim)s(?: + (?P%(delim)s) | + (?P%(id)s) | + {(?P%(bid)s)} | + (?P) + ) + """ % { + 'delim': re.escape('$'), + 'id': r'[_a-z][_a-z0-9]*', + 'bid': r'[_a-z][_a-z0-9]*(?:(?P:?[-?])[^}]*)?', + } + + @staticmethod + def process_braced_group(braced, sep, mapping): + if ':-' == sep: + var, _, default = braced.partition(':-') + return mapping.get(var) or default + elif '-' == sep: + var, _, default = braced.partition('-') + return mapping.get(var, default) + + elif ':?' == sep: + var, _, err = braced.partition(':?') + result = mapping.get(var) + if not result: + raise UnsetRequiredSubstitution(err) + return result + elif '?' == sep: + var, _, err = braced.partition('?') + if var in mapping: + return mapping.get(var) + raise UnsetRequiredSubstitution(err) # Modified from python2.7/string.py def substitute(self, mapping): # Helper function for .sub() + def convert(mo): - # Check the most common path first. named = mo.group('named') or mo.group('braced') + braced = mo.group('braced') + if braced is not None: + sep = mo.group('sep') + result = self.process_braced_group(braced, sep, mapping) + if result: + return result + if named is not None: - if ':-' in named: - var, _, default = named.partition(':-') - return mapping.get(var) or default - if '-' in named: - var, _, default = named.partition('-') - return mapping.get(var, default) val = mapping[named] return '%s' % (val,) if mo.group('escaped') is not None: @@ -110,6 +152,11 @@ class InvalidInterpolation(Exception): self.string = string +class UnsetRequiredSubstitution(Exception): + def __init__(self, custom_err_msg): + self.err = custom_err_msg + + PATH_JOKER = '[^.]+' diff --git a/tests/unit/cli/formatter_test.py b/tests/unit/cli/formatter_test.py index 4aa025e69..e68572511 100644 --- a/tests/unit/cli/formatter_test.py +++ b/tests/unit/cli/formatter_test.py @@ -37,7 +37,6 @@ class ConsoleWarningFormatterTestCase(unittest.TestCase): def test_format_unicode_info(self): message = b'\xec\xa0\x95\xec\x88\x98\xec\xa0\x95' output = self.formatter.format(make_log_record(logging.INFO, message)) - print(output) assert output == message.decode('utf-8') def test_format_unicode_warn(self): diff --git a/tests/unit/config/interpolation_test.py b/tests/unit/config/interpolation_test.py index 702ea682d..dfeba96d0 100644 --- a/tests/unit/config/interpolation_test.py +++ b/tests/unit/config/interpolation_test.py @@ -9,6 +9,7 @@ from compose.config.interpolation import interpolate_environment_variables from compose.config.interpolation import Interpolator from compose.config.interpolation import InvalidInterpolation from compose.config.interpolation import TemplateWithDefaults +from compose.config.interpolation import UnsetRequiredSubstitution from compose.const import COMPOSEFILE_V2_0 as V2_0 from compose.const import COMPOSEFILE_V2_3 as V2_3 from compose.const import COMPOSEFILE_V3_4 as V3_4 @@ -357,9 +358,46 @@ def test_interpolate_with_value(defaults_interpolator): def test_interpolate_missing_with_default(defaults_interpolator): assert defaults_interpolator("ok ${missing:-def}") == "ok def" assert defaults_interpolator("ok ${missing-def}") == "ok def" - assert defaults_interpolator("ok ${BAR:-/non:-alphanumeric}") == "ok /non:-alphanumeric" def test_interpolate_with_empty_and_default_value(defaults_interpolator): assert defaults_interpolator("ok ${BAR:-def}") == "ok def" assert defaults_interpolator("ok ${BAR-def}") == "ok " + + +def test_interpolate_mandatory_values(defaults_interpolator): + assert defaults_interpolator("ok ${FOO:?bar}") == "ok first" + assert defaults_interpolator("ok ${FOO?bar}") == "ok first" + assert defaults_interpolator("ok ${BAR?bar}") == "ok " + + with pytest.raises(UnsetRequiredSubstitution) as e: + defaults_interpolator("not ok ${BAR:?high bar}") + assert e.value.err == 'high bar' + + with pytest.raises(UnsetRequiredSubstitution) as e: + defaults_interpolator("not ok ${BAZ?dropped the bazz}") + assert e.value.err == 'dropped the bazz' + + +def test_interpolate_mandatory_no_err_msg(defaults_interpolator): + with pytest.raises(UnsetRequiredSubstitution) as e: + defaults_interpolator("not ok ${BAZ?}") + + assert e.value.err == '' + + +def test_interpolate_mixed_separators(defaults_interpolator): + assert defaults_interpolator("ok ${BAR:-/non:-alphanumeric}") == "ok /non:-alphanumeric" + assert defaults_interpolator("ok ${BAR:-:?wwegegr??:?}") == "ok :?wwegegr??:?" + assert defaults_interpolator("ok ${BAR-:-hello}") == 'ok ' + + with pytest.raises(UnsetRequiredSubstitution) as e: + defaults_interpolator("not ok ${BAR:?xazz:-redf}") + assert e.value.err == 'xazz:-redf' + + assert defaults_interpolator("ok ${BAR?...:?bar}") == "ok " + + +def test_unbraced_separators(defaults_interpolator): + assert defaults_interpolator("ok $FOO:-bar") == "ok first:-bar" + assert defaults_interpolator("ok $BAZ?error") == "ok ?error"