From e32863f89ebe0c70143695525e5062ae1c8f375c Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Mon, 1 Feb 2016 13:47:13 -0500 Subject: [PATCH] Make links unique-by-alias when merging Factor out MergeDict from merge_service_dicts to reduce complexity below limit. Signed-off-by: Daniel Nephin --- compose/config/config.py | 85 ++++++++++++++++++++++++++++------------ compose/config/types.py | 19 +++++++++ 2 files changed, 78 insertions(+), 26 deletions(-) diff --git a/compose/config/config.py b/compose/config/config.py index 07f622903..f362f1b80 100644 --- a/compose/config/config.py +++ b/compose/config/config.py @@ -26,6 +26,7 @@ from .sort_services import get_service_name_from_network_mode from .sort_services import sort_service_dicts from .types import parse_extra_hosts from .types import parse_restart_spec +from .types import ServiceLink from .types import VolumeFromSpec from .types import VolumeSpec from .validation import match_named_volumes @@ -641,51 +642,79 @@ def merge_service_dicts_from_files(base, override, version): return new_service +class MergeDict(dict): + """A dict-like object responsible for merging two dicts into one.""" + + def __init__(self, base, override): + self.base = base + self.override = override + + def needs_merge(self, field): + return field in self.base or field in self.override + + def merge_field(self, field, merge_func, default=None): + if not self.needs_merge(field): + return + + self[field] = merge_func( + self.base.get(field, default), + self.override.get(field, default)) + + def merge_mapping(self, field, parse_func): + if not self.needs_merge(field): + return + + self[field] = parse_func(self.base.get(field)) + self[field].update(parse_func(self.override.get(field))) + + def merge_sequence(self, field, parse_func): + def parse_sequence_func(seq): + return to_mapping((parse_func(item) for item in seq), 'merge_field') + + if not self.needs_merge(field): + return + + merged = parse_sequence_func(self.base.get(field, [])) + merged.update(parse_sequence_func(self.override.get(field, []))) + self[field] = [item.repr() for item in merged.values()] + + def merge_scalar(self, field): + if self.needs_merge(field): + self[field] = self.override.get(field, self.base.get(field)) + + def merge_service_dicts(base, override, version): - d = {} + md = MergeDict(base, override) - def merge_field(field, merge_func, default=None): - if field in base or field in override: - d[field] = merge_func( - base.get(field, default), - override.get(field, default)) - - def merge_mapping(mapping, parse_func): - if mapping in base or mapping in override: - merged = parse_func(base.get(mapping, None)) - merged.update(parse_func(override.get(mapping, None))) - d[mapping] = merged - - merge_mapping('environment', parse_environment) - merge_mapping('labels', parse_labels) - merge_mapping('ulimits', parse_ulimits) + md.merge_mapping('environment', parse_environment) + md.merge_mapping('labels', parse_labels) + md.merge_mapping('ulimits', parse_ulimits) + md.merge_sequence('links', ServiceLink.parse) for field in ['volumes', 'devices']: - merge_field(field, merge_path_mappings) + md.merge_field(field, merge_path_mappings) for field in [ 'depends_on', 'expose', 'external_links', - 'links', 'ports', 'volumes_from', ]: - merge_field(field, operator.add, default=[]) + md.merge_field(field, operator.add, default=[]) for field in ['dns', 'dns_search', 'env_file']: - merge_field(field, merge_list_or_string) + md.merge_field(field, merge_list_or_string) - for field in set(ALLOWED_KEYS) - set(d): - if field in base or field in override: - d[field] = override.get(field, base.get(field)) + for field in set(ALLOWED_KEYS) - set(md): + md.merge_scalar(field) if version == V1: - legacy_v1_merge_image_or_build(d, base, override) + legacy_v1_merge_image_or_build(md, base, override) else: - merge_build(d, base, override) + merge_build(md, base, override) - return d + return dict(md) def merge_build(output, base, override): @@ -919,6 +948,10 @@ def to_list(value): return value +def to_mapping(sequence, key_field): + return {getattr(item, key_field): item for item in sequence} + + def has_uppercase(name): return any(char in string.ascii_uppercase for char in name) diff --git a/compose/config/types.py b/compose/config/types.py index 9bda71806..fc3347c86 100644 --- a/compose/config/types.py +++ b/compose/config/types.py @@ -168,3 +168,22 @@ class VolumeSpec(namedtuple('_VolumeSpec', 'external internal mode')): @property def is_named_volume(self): return self.external and not self.external.startswith(('.', '/', '~')) + + +class ServiceLink(namedtuple('_ServiceLink', 'target alias')): + + @classmethod + def parse(cls, link_spec): + target, _, alias = link_spec.partition(':') + if not alias: + alias = target + return cls(target, alias) + + def repr(self): + if self.target == self.alias: + return self.target + return '{s.target}:{s.alias}'.format(s=self) + + @property + def merge_field(self): + return self.alias