From 907918b492deb7543a2ce4385ea5a3a3228ff93d Mon Sep 17 00:00:00 2001 From: Aanand Prasad Date: Mon, 30 Mar 2015 18:20:34 -0400 Subject: [PATCH] Merge multi-value options when extending Closes #1143. Signed-off-by: Aanand Prasad --- compose/config.py | 30 ++++++++++++++--- tests/unit/config_test.py | 71 +++++++++++++++++++++++++++++++++++---- 2 files changed, 90 insertions(+), 11 deletions(-) diff --git a/compose/config.py b/compose/config.py index dea734d51..022069fdf 100644 --- a/compose/config.py +++ b/compose/config.py @@ -195,10 +195,23 @@ def merge_service_dicts(base, override): if 'build' in override and 'image' in d: del d['image'] - for k in ALLOWED_KEYS: - if k not in ['environment', 'volumes']: - if k in override: - d[k] = override[k] + list_keys = ['ports', 'expose', 'external_links'] + + for key in list_keys: + if key in base or key in override: + d[key] = base.get(key, []) + override.get(key, []) + + list_or_string_keys = ['dns', 'dns_search'] + + for key in list_or_string_keys: + if key in base or key in override: + d[key] = to_list(base.get(key)) + to_list(override.get(key)) + + already_merged_keys = ['environment', 'volumes'] + list_keys + list_or_string_keys + + for k in set(ALLOWED_KEYS) - set(already_merged_keys): + if k in override: + d[k] = override[k] return d @@ -354,6 +367,15 @@ def expand_path(working_dir, path): return os.path.abspath(os.path.join(working_dir, path)) +def to_list(value): + if value is None: + return [] + elif isinstance(value, six.string_types): + return [value] + else: + return value + + def get_service_name_from_net(net_config): if not net_config: return diff --git a/tests/unit/config_test.py b/tests/unit/config_test.py index 280034449..af3bebb33 100644 --- a/tests/unit/config_test.py +++ b/tests/unit/config_test.py @@ -40,40 +40,40 @@ class ConfigTest(unittest.TestCase): config.make_service_dict('foo', {'ports': ['8000']}) -class MergeTest(unittest.TestCase): - def test_merge_volumes_empty(self): +class MergeVolumesTest(unittest.TestCase): + def test_empty(self): service_dict = config.merge_service_dicts({}, {}) self.assertNotIn('volumes', service_dict) - def test_merge_volumes_no_override(self): + def test_no_override(self): service_dict = config.merge_service_dicts( {'volumes': ['/foo:/code', '/data']}, {}, ) self.assertEqual(set(service_dict['volumes']), set(['/foo:/code', '/data'])) - def test_merge_volumes_no_base(self): + def test_no_base(self): service_dict = config.merge_service_dicts( {}, {'volumes': ['/bar:/code']}, ) self.assertEqual(set(service_dict['volumes']), set(['/bar:/code'])) - def test_merge_volumes_override_explicit_path(self): + def test_override_explicit_path(self): service_dict = config.merge_service_dicts( {'volumes': ['/foo:/code', '/data']}, {'volumes': ['/bar:/code']}, ) self.assertEqual(set(service_dict['volumes']), set(['/bar:/code', '/data'])) - def test_merge_volumes_add_explicit_path(self): + def test_add_explicit_path(self): service_dict = config.merge_service_dicts( {'volumes': ['/foo:/code', '/data']}, {'volumes': ['/bar:/code', '/quux:/data']}, ) self.assertEqual(set(service_dict['volumes']), set(['/bar:/code', '/quux:/data'])) - def test_merge_volumes_remove_explicit_path(self): + def test_remove_explicit_path(self): service_dict = config.merge_service_dicts( {'volumes': ['/foo:/code', '/quux:/data']}, {'volumes': ['/bar:/code', '/data']}, @@ -114,6 +114,63 @@ class MergeTest(unittest.TestCase): ) +class MergeListsTest(unittest.TestCase): + def test_empty(self): + service_dict = config.merge_service_dicts({}, {}) + self.assertNotIn('ports', service_dict) + + def test_no_override(self): + service_dict = config.merge_service_dicts( + {'ports': ['10:8000', '9000']}, + {}, + ) + self.assertEqual(set(service_dict['ports']), set(['10:8000', '9000'])) + + def test_no_base(self): + service_dict = config.merge_service_dicts( + {}, + {'ports': ['10:8000', '9000']}, + ) + self.assertEqual(set(service_dict['ports']), set(['10:8000', '9000'])) + + def test_add_item(self): + service_dict = config.merge_service_dicts( + {'ports': ['10:8000', '9000']}, + {'ports': ['20:8000']}, + ) + self.assertEqual(set(service_dict['ports']), set(['10:8000', '9000', '20:8000'])) + + +class MergeStringsOrListsTest(unittest.TestCase): + def test_no_override(self): + service_dict = config.merge_service_dicts( + {'dns': '8.8.8.8'}, + {}, + ) + self.assertEqual(set(service_dict['dns']), set(['8.8.8.8'])) + + def test_no_base(self): + service_dict = config.merge_service_dicts( + {}, + {'dns': '8.8.8.8'}, + ) + self.assertEqual(set(service_dict['dns']), set(['8.8.8.8'])) + + def test_add_string(self): + service_dict = config.merge_service_dicts( + {'dns': ['8.8.8.8']}, + {'dns': '9.9.9.9'}, + ) + self.assertEqual(set(service_dict['dns']), set(['8.8.8.8', '9.9.9.9'])) + + def test_add_list(self): + service_dict = config.merge_service_dicts( + {'dns': '8.8.8.8'}, + {'dns': ['9.9.9.9']}, + ) + self.assertEqual(set(service_dict['dns']), set(['8.8.8.8', '9.9.9.9'])) + + class EnvTest(unittest.TestCase): def test_parse_environment_as_list(self): environment = [