Make links unique-by-alias when merging

Factor out MergeDict from merge_service_dicts to reduce complexity below limit.

Signed-off-by: Daniel Nephin <dnephin@docker.com>
This commit is contained in:
Daniel Nephin 2016-02-01 13:47:13 -05:00
parent bf6a5d3e49
commit e32863f89e
2 changed files with 78 additions and 26 deletions

View File

@ -26,6 +26,7 @@ from .sort_services import get_service_name_from_network_mode
from .sort_services import sort_service_dicts from .sort_services import sort_service_dicts
from .types import parse_extra_hosts from .types import parse_extra_hosts
from .types import parse_restart_spec from .types import parse_restart_spec
from .types import ServiceLink
from .types import VolumeFromSpec from .types import VolumeFromSpec
from .types import VolumeSpec from .types import VolumeSpec
from .validation import match_named_volumes from .validation import match_named_volumes
@ -641,51 +642,79 @@ def merge_service_dicts_from_files(base, override, version):
return new_service return new_service
class MergeDict(dict):
"""A dict-like object responsible for merging two dicts into one."""
def __init__(self, base, override):
self.base = base
self.override = override
def needs_merge(self, field):
return field in self.base or field in self.override
def merge_field(self, field, merge_func, default=None):
if not self.needs_merge(field):
return
self[field] = merge_func(
self.base.get(field, default),
self.override.get(field, default))
def merge_mapping(self, field, parse_func):
if not self.needs_merge(field):
return
self[field] = parse_func(self.base.get(field))
self[field].update(parse_func(self.override.get(field)))
def merge_sequence(self, field, parse_func):
def parse_sequence_func(seq):
return to_mapping((parse_func(item) for item in seq), 'merge_field')
if not self.needs_merge(field):
return
merged = parse_sequence_func(self.base.get(field, []))
merged.update(parse_sequence_func(self.override.get(field, [])))
self[field] = [item.repr() for item in merged.values()]
def merge_scalar(self, field):
if self.needs_merge(field):
self[field] = self.override.get(field, self.base.get(field))
def merge_service_dicts(base, override, version): def merge_service_dicts(base, override, version):
d = {} md = MergeDict(base, override)
def merge_field(field, merge_func, default=None): md.merge_mapping('environment', parse_environment)
if field in base or field in override: md.merge_mapping('labels', parse_labels)
d[field] = merge_func( md.merge_mapping('ulimits', parse_ulimits)
base.get(field, default), md.merge_sequence('links', ServiceLink.parse)
override.get(field, default))
def merge_mapping(mapping, parse_func):
if mapping in base or mapping in override:
merged = parse_func(base.get(mapping, None))
merged.update(parse_func(override.get(mapping, None)))
d[mapping] = merged
merge_mapping('environment', parse_environment)
merge_mapping('labels', parse_labels)
merge_mapping('ulimits', parse_ulimits)
for field in ['volumes', 'devices']: for field in ['volumes', 'devices']:
merge_field(field, merge_path_mappings) md.merge_field(field, merge_path_mappings)
for field in [ for field in [
'depends_on', 'depends_on',
'expose', 'expose',
'external_links', 'external_links',
'links',
'ports', 'ports',
'volumes_from', 'volumes_from',
]: ]:
merge_field(field, operator.add, default=[]) md.merge_field(field, operator.add, default=[])
for field in ['dns', 'dns_search', 'env_file']: for field in ['dns', 'dns_search', 'env_file']:
merge_field(field, merge_list_or_string) md.merge_field(field, merge_list_or_string)
for field in set(ALLOWED_KEYS) - set(d): for field in set(ALLOWED_KEYS) - set(md):
if field in base or field in override: md.merge_scalar(field)
d[field] = override.get(field, base.get(field))
if version == V1: if version == V1:
legacy_v1_merge_image_or_build(d, base, override) legacy_v1_merge_image_or_build(md, base, override)
else: else:
merge_build(d, base, override) merge_build(md, base, override)
return d return dict(md)
def merge_build(output, base, override): def merge_build(output, base, override):
@ -919,6 +948,10 @@ def to_list(value):
return value return value
def to_mapping(sequence, key_field):
return {getattr(item, key_field): item for item in sequence}
def has_uppercase(name): def has_uppercase(name):
return any(char in string.ascii_uppercase for char in name) return any(char in string.ascii_uppercase for char in name)

View File

@ -168,3 +168,22 @@ class VolumeSpec(namedtuple('_VolumeSpec', 'external internal mode')):
@property @property
def is_named_volume(self): def is_named_volume(self):
return self.external and not self.external.startswith(('.', '/', '~')) return self.external and not self.external.startswith(('.', '/', '~'))
class ServiceLink(namedtuple('_ServiceLink', 'target alias')):
@classmethod
def parse(cls, link_spec):
target, _, alias = link_spec.partition(':')
if not alias:
alias = target
return cls(target, alias)
def repr(self):
if self.target == self.alias:
return self.target
return '{s.target}:{s.alias}'.format(s=self)
@property
def merge_field(self):
return self.alias