Refactor config loading to move version check into ConfigFile.

Adds the cached_property package.

Signed-off-by: Daniel Nephin <dnephin@docker.com>
This commit is contained in:
Daniel Nephin 2016-01-13 13:28:39 -05:00
parent de949284f5
commit c3968a439f
4 changed files with 55 additions and 50 deletions

View File

@ -10,6 +10,7 @@ from collections import namedtuple
import six
import yaml
from cached_property import cached_property
from ..const import COMPOSEFILE_VERSIONS
from .errors import CircularReference
@ -119,11 +120,23 @@ class ConfigFile(namedtuple('_ConfigFile', 'filename config')):
def from_filename(cls, filename):
return cls(filename, load_yaml(filename))
def get_service_dicts(self, version):
return self.config if version == 1 else self.config.get('services', {})
@cached_property
def version(self):
if self.config is None:
return 1
version = self.config.get('version', 1)
if isinstance(version, dict):
log.warn("Unexpected type for field 'version', in file {} assuming "
"version is the name of a service, and defaulting to "
"Compose file version 1".format(self.filename))
return 1
return version
def get_volumes(self, version):
return {} if version == 1 else self.config.get('volumes', {})
def get_service_dicts(self):
return self.config if self.version == 1 else self.config.get('services', {})
def get_volumes(self):
return {} if self.version == 1 else self.config.get('volumes', {})
class Config(namedtuple('_Config', 'version services volumes')):
@ -168,32 +181,24 @@ def find(base_dir, filenames):
[ConfigFile.from_filename(f) for f in filenames])
def get_config_version(config_details):
def get_version(config):
if config.config is None:
return 1
version = config.config.get('version', 1)
if isinstance(version, dict):
# in that case 'version' is probably a service name, so assume
# this is a legacy (version=1) file
version = 1
return version
def validate_config_version(config_details):
main_file = config_details.config_files[0]
validate_top_level_object(main_file)
version = get_version(main_file)
for next_file in config_details.config_files[1:]:
validate_top_level_object(next_file)
next_file_version = get_version(next_file)
if version != next_file_version and next_file_version is not None:
if main_file.version != next_file.version:
raise ConfigurationError(
"Version mismatch: main file {0} specifies version {1} but "
"Version mismatch: file {0} specifies version {1} but "
"extension file {2} uses version {3}".format(
main_file.filename, version, next_file.filename, next_file_version
)
)
return version
main_file.filename,
main_file.version,
next_file.filename,
next_file.version))
if main_file.version not in COMPOSEFILE_VERSIONS:
raise ConfigurationError(
'Invalid Compose file version: {0}'.format(main_file.version))
def get_default_config_files(base_dir):
@ -242,23 +247,22 @@ def load(config_details):
Return a fully interpolated, extended and validated configuration.
"""
version = get_config_version(config_details)
if version not in COMPOSEFILE_VERSIONS:
raise ConfigurationError('Invalid config version provided: {0}'.format(version))
validate_config_version(config_details)
processed_files = [
process_config_file(config_file, version=version)
process_config_file(config_file)
for config_file in config_details.config_files
]
config_details = config_details._replace(config_files=processed_files)
main_file = config_details.config_files[0]
volumes = load_volumes(config_details.config_files)
service_dicts = load_services(
config_details.working_dir,
config_details.config_files[0].filename,
[file.get_service_dicts(version) for file in config_details.config_files],
version)
return Config(version, service_dicts, volumes)
main_file.filename,
[file.get_service_dicts() for file in config_details.config_files],
main_file.version)
return Config(main_file.version, service_dicts, volumes)
def load_volumes(config_files):
@ -328,27 +332,28 @@ def load_services(working_dir, filename, service_configs, version):
return build_services(service_config)
def process_config_file(config_file, version, service_name=None):
service_dicts = config_file.get_service_dicts(version)
validate_top_level_service_objects(
config_file.filename, service_dicts
)
def process_config_file(config_file, service_name=None):
service_dicts = config_file.get_service_dicts()
validate_top_level_service_objects(config_file.filename, service_dicts)
# TODO: interpolate config in volumes/network sections as well
interpolated_config = interpolate_environment_variables(service_dicts)
if version == 2:
if config_file.version == 2:
processed_config = dict(config_file.config)
processed_config.update({'services': interpolated_config})
if version == 1:
if config_file.version == 1:
processed_config = interpolated_config
validate_against_fields_schema(
processed_config, config_file.filename, version
)
config_file = config_file._replace(config=processed_config)
validate_against_fields_schema(config_file)
if service_name and service_name not in processed_config:
raise ConfigurationError(
"Cannot extend service '{}' in {}: Service not found".format(
service_name, config_file.filename))
return config_file._replace(config=processed_config)
return config_file
class ServiceExtendsResolver(object):
@ -385,8 +390,7 @@ class ServiceExtendsResolver(object):
extended_file = process_config_file(
ConfigFile.from_filename(config_path),
version=self.version, service_name=service_name
)
service_name=service_name)
service_config = extended_file.config[service_name]
return config_path, service_config, service_name

View File

@ -105,8 +105,7 @@ def validate_top_level_service_objects(filename, service_dicts):
def validate_top_level_object(config_file):
if not isinstance(config_file.config, dict):
raise ConfigurationError(
"Top level object in '{}' needs to be an object not '{}'. Check "
"that you have defined a service at the top level.".format(
"Top level object in '{}' needs to be an object not '{}'.".format(
config_file.filename,
type(config_file.config)))
@ -291,13 +290,13 @@ def process_errors(errors, service_name=None):
return '\n'.join(format_error_message(error, service_name) for error in errors)
def validate_against_fields_schema(config, filename, version):
schema_filename = "fields_schema_v{0}.json".format(version)
def validate_against_fields_schema(config_file):
schema_filename = "fields_schema_v{0}.json".format(config_file.version)
_validate_against_schema(
config,
config_file.config,
schema_filename,
format_checker=["ports", "expose", "bool-value-in-mapping"],
filename=filename)
filename=config_file.filename)
def validate_against_service_schema(config, service_name, version):

View File

@ -1,4 +1,5 @@
PyYAML==3.11
cached-property==1.2.0
dockerpty==0.3.4
docopt==0.6.1
enum34==1.0.4

View File

@ -28,6 +28,7 @@ def find_version(*file_paths):
install_requires = [
'cached-property >= 1.2.0',
'docopt >= 0.6.1, < 0.7',
'PyYAML >= 3.10, < 4',
'requests >= 2.6.1, < 2.8',