diff --git a/compose/config/config.py b/compose/config/config.py index b38942253..ac5e8d174 100644 --- a/compose/config/config.py +++ b/compose/config/config.py @@ -188,10 +188,10 @@ def find(base_dir, filenames): [ConfigFile.from_filename(f) for f in filenames]) -def validate_config_version(config_details): - main_file = config_details.config_files[0] +def validate_config_version(config_files): + main_file = config_files[0] validate_top_level_object(main_file) - for next_file in config_details.config_files[1:]: + for next_file in config_files[1:]: validate_top_level_object(next_file) if main_file.version != next_file.version: @@ -254,7 +254,7 @@ def load(config_details): Return a fully interpolated, extended and validated configuration. """ - validate_config_version(config_details) + validate_config_version(config_details.config_files) processed_files = [ process_config_file(config_file) @@ -267,9 +267,8 @@ def load(config_details): networks = load_mapping(config_details.config_files, 'get_networks', 'Network') service_dicts = load_services( config_details.working_dir, - main_file.filename, - [file.get_service_dicts() for file in config_details.config_files], - main_file.version) + main_file, + [file.get_service_dicts() for file in config_details.config_files]) return Config(main_file.version, service_dicts, volumes, networks) @@ -303,21 +302,21 @@ def load_mapping(config_files, get_func, entity_type): return mapping -def load_services(working_dir, filename, service_configs, version): +def load_services(working_dir, config_file, service_configs): def build_service(service_name, service_dict, service_names): service_config = ServiceConfig.with_abs_paths( working_dir, - filename, + config_file.filename, service_name, service_dict) - resolver = ServiceExtendsResolver(service_config, version) + resolver = ServiceExtendsResolver(service_config, config_file) service_dict = process_service(resolver.run()) - validate_service(service_dict, service_config.name, version) + validate_service(service_dict, service_config.name, config_file.version) service_dict = finalize_service( service_config._replace(config=service_dict), service_names, - version) + config_file.version) return service_dict def build_services(service_config): @@ -333,7 +332,7 @@ def load_services(working_dir, filename, service_configs, version): name: merge_service_dicts_from_files( base.get(name, {}), override.get(name, {}), - version) + config_file.version) for name in all_service_names } @@ -373,11 +372,11 @@ def process_config_file(config_file, service_name=None): class ServiceExtendsResolver(object): - def __init__(self, service_config, version, already_seen=None): + def __init__(self, service_config, config_file, already_seen=None): self.service_config = service_config self.working_dir = service_config.working_dir self.already_seen = already_seen or [] - self.version = version + self.config_file = config_file @property def signature(self): @@ -404,8 +403,10 @@ class ServiceExtendsResolver(object): config_path = self.get_extended_config_path(extends) service_name = extends['service'] + extends_file = ConfigFile.from_filename(config_path) + validate_config_version([self.config_file, extends_file]) extended_file = process_config_file( - ConfigFile.from_filename(config_path), + extends_file, service_name=service_name) service_config = extended_file.config[service_name] return config_path, service_config, service_name @@ -417,7 +418,7 @@ class ServiceExtendsResolver(object): extended_config_path, service_name, service_dict), - self.version, + self.config_file, already_seen=self.already_seen + [self.signature]) service_config = resolver.run() @@ -425,13 +426,12 @@ class ServiceExtendsResolver(object): validate_extended_service_dict( other_service_dict, extended_config_path, - service_name, - ) + service_name) return merge_service_dicts( other_service_dict, self.service_config.config, - self.version) + self.config_file.version) def get_extended_config_path(self, extends_options): """Service we are extending either has a value for 'file' set, which we diff --git a/tests/unit/config/config_test.py b/tests/unit/config/config_test.py index e24dc9041..cc2051363 100644 --- a/tests/unit/config/config_test.py +++ b/tests/unit/config/config_test.py @@ -25,14 +25,15 @@ V1 = 1 def make_service_dict(name, service_dict, working_dir, filename=None): + """Test helper function to construct a ServiceExtendsResolver """ - Test helper function to construct a ServiceExtendsResolver - """ - resolver = config.ServiceExtendsResolver(config.ServiceConfig( - working_dir=working_dir, - filename=filename, - name=name, - config=service_dict), version=1) + resolver = config.ServiceExtendsResolver( + config.ServiceConfig( + working_dir=working_dir, + filename=filename, + name=name, + config=service_dict), + config.ConfigFile(filename=filename, config={})) return config.process_service(resolver.run()) @@ -1888,6 +1889,28 @@ class ExtendsTest(unittest.TestCase): assert config == expected + def test_extends_with_mixed_versions_is_error(self): + tmpdir = py.test.ensuretemp('test_extends_with_mixed_version') + self.addCleanup(tmpdir.remove) + tmpdir.join('docker-compose.yml').write(""" + version: 2 + services: + web: + extends: + file: base.yml + service: base + image: busybox + """) + tmpdir.join('base.yml').write(""" + base: + volumes: ['/foo'] + ports: ['3000:3000'] + """) + + with pytest.raises(ConfigurationError) as exc: + load_from_filename(str(tmpdir.join('docker-compose.yml'))) + assert 'Version mismatch' in exc.exconly() + @pytest.mark.xfail(IS_WINDOWS_PLATFORM, reason='paths use slash') class ExpandPathTest(unittest.TestCase):