From 79df2ebe1bbe81232acd84eeca7bf66af8e3004b Mon Sep 17 00:00:00 2001 From: Daniel Nephin <dnephin@docker.com> Date: Wed, 13 Jan 2016 15:19:02 -0500 Subject: [PATCH] Support variable interpolation for volumes and networks sections. Signed-off-by: Daniel Nephin <dnephin@docker.com> --- compose/config/config.py | 21 +++++--- compose/config/interpolation.py | 31 +++++------ tests/unit/config/config_test.py | 16 +++--- tests/unit/config/interpolation_test.py | 69 +++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 33 deletions(-) create mode 100644 tests/unit/config/interpolation_test.py diff --git a/compose/config/config.py b/compose/config/config.py index c8d93faf6..f6df3d3bf 100644 --- a/compose/config/config.py +++ b/compose/config/config.py @@ -138,6 +138,9 @@ class ConfigFile(namedtuple('_ConfigFile', 'filename config')): def get_volumes(self): return {} if self.version == 1 else self.config.get('volumes', {}) + def get_networks(self): + return {} if self.version == 1 else self.config.get('networks', {}) + class Config(namedtuple('_Config', 'version services volumes networks')): """ @@ -258,8 +261,8 @@ def load(config_details): config_details = config_details._replace(config_files=processed_files) main_file = config_details.config_files[0] - volumes = load_mapping(config_details.config_files, 'volumes', 'Volume') - networks = load_mapping(config_details.config_files, 'networks', 'Network') + volumes = load_mapping(config_details.config_files, 'get_volumes', 'Volume') + networks = load_mapping(config_details.config_files, 'get_networks', 'Network') service_dicts = load_services( config_details.working_dir, main_file.filename, @@ -268,11 +271,11 @@ def load(config_details): return Config(main_file.version, service_dicts, volumes, networks) -def load_mapping(config_files, key, entity_type): +def load_mapping(config_files, get_func, entity_type): mapping = {} for config_file in config_files: - for name, config in config_file.config.get(key, {}).items(): + for name, config in getattr(config_file, get_func)().items(): mapping[name] = config or {} if not config: continue @@ -347,12 +350,16 @@ def process_config_file(config_file, service_name=None): service_dicts = config_file.get_service_dicts() validate_top_level_service_objects(config_file.filename, service_dicts) - # TODO: interpolate config in volumes/network sections as well - interpolated_config = interpolate_environment_variables(service_dicts) + interpolated_config = interpolate_environment_variables(service_dicts, 'service') if config_file.version == 2: processed_config = dict(config_file.config) - processed_config.update({'services': interpolated_config}) + processed_config['services'] = interpolated_config + processed_config['volumes'] = interpolate_environment_variables( + config_file.get_volumes(), 'volume') + processed_config['networks'] = interpolate_environment_variables( + config_file.get_networks(), 'network') + if config_file.version == 1: processed_config = interpolated_config diff --git a/compose/config/interpolation.py b/compose/config/interpolation.py index 7a7576448..e1c781fec 100644 --- a/compose/config/interpolation.py +++ b/compose/config/interpolation.py @@ -11,35 +11,32 @@ from .errors import ConfigurationError log = logging.getLogger(__name__) -def interpolate_environment_variables(service_dicts): +def interpolate_environment_variables(config, section): mapping = BlankDefaultDict(os.environ) + def process_item(name, config_dict): + return dict( + (key, interpolate_value(name, key, val, section, mapping)) + for key, val in (config_dict or {}).items() + ) + return dict( - (service_name, process_service(service_name, service_dict, mapping)) - for (service_name, service_dict) in service_dicts.items() + (name, process_item(name, config_dict)) + for name, config_dict in config.items() ) -def process_service(service_name, service_dict, mapping): - return dict( - (key, interpolate_value(service_name, key, val, mapping)) - for (key, val) in service_dict.items() - ) - - -def interpolate_value(service_name, config_key, value, mapping): +def interpolate_value(name, config_key, value, section, mapping): try: return recursive_interpolate(value, mapping) except InvalidInterpolation as e: raise ConfigurationError( 'Invalid interpolation format for "{config_key}" option ' - 'in service "{service_name}": "{string}"' - .format( + 'in {section} "{name}": "{string}"'.format( config_key=config_key, - service_name=service_name, - string=e.string, - ) - ) + name=name, + section=section, + string=e.string)) def recursive_interpolate(obj, mapping): diff --git a/tests/unit/config/config_test.py b/tests/unit/config/config_test.py index f0d432d58..f88166432 100644 --- a/tests/unit/config/config_test.py +++ b/tests/unit/config/config_test.py @@ -686,8 +686,8 @@ class ConfigTest(unittest.TestCase): ) ) - self.assertTrue(mock_logging.warn.called) - self.assertTrue(expected_warning_msg in mock_logging.warn.call_args[0][0]) + assert mock_logging.warn.called + assert expected_warning_msg in mock_logging.warn.call_args[0][0] def test_config_valid_environment_dict_key_contains_dashes(self): services = config.load( @@ -1664,15 +1664,13 @@ class ExtendsTest(unittest.TestCase): load_from_filename('tests/fixtures/extends/invalid-net.yml') @mock.patch.dict(os.environ) - def test_valid_interpolation_in_extended_service(self): - os.environ.update( - HOSTNAME_VALUE="penguin", - ) + def test_load_config_runs_interpolation_in_extended_service(self): + os.environ.update(HOSTNAME_VALUE="penguin") expected_interpolated_value = "host-penguin" - - service_dicts = load_from_filename('tests/fixtures/extends/valid-interpolation.yml') + service_dicts = load_from_filename( + 'tests/fixtures/extends/valid-interpolation.yml') for service in service_dicts: - self.assertTrue(service['hostname'], expected_interpolated_value) + assert service['hostname'] == expected_interpolated_value @pytest.mark.xfail(IS_WINDOWS_PLATFORM, reason='paths use slash') def test_volume_path(self): diff --git a/tests/unit/config/interpolation_test.py b/tests/unit/config/interpolation_test.py new file mode 100644 index 000000000..0691e8865 --- /dev/null +++ b/tests/unit/config/interpolation_test.py @@ -0,0 +1,69 @@ +from __future__ import absolute_import +from __future__ import unicode_literals + +import os + +import mock +import pytest + +from compose.config.interpolation import interpolate_environment_variables + + +@pytest.yield_fixture +def mock_env(): + with mock.patch.dict(os.environ): + os.environ['USER'] = 'jenny' + os.environ['FOO'] = 'bar' + yield + + +def test_interpolate_environment_variables_in_services(mock_env): + services = { + 'servivea': { + 'image': 'example:${USER}', + 'volumes': ['$FOO:/target'], + 'logging': { + 'driver': '${FOO}', + 'options': { + 'user': '$USER', + } + } + } + } + expected = { + 'servivea': { + 'image': 'example:jenny', + 'volumes': ['bar:/target'], + 'logging': { + 'driver': 'bar', + 'options': { + 'user': 'jenny', + } + } + } + } + assert interpolate_environment_variables(services, 'service') == expected + + +def test_interpolate_environment_variables_in_volumes(mock_env): + volumes = { + 'data': { + 'driver': '$FOO', + 'driver_opts': { + 'max': 2, + 'user': '${USER}' + } + }, + 'other': None, + } + expected = { + 'data': { + 'driver': 'bar', + 'driver_opts': { + 'max': 2, + 'user': 'jenny' + } + }, + 'other': {}, + } + assert interpolate_environment_variables(volumes, 'volume') == expected