diff --git a/compose/config.py b/compose/config.py index 2427476f7..c054213eb 100644 --- a/compose/config.py +++ b/compose/config.py @@ -158,7 +158,7 @@ class ServiceLoader(object): if 'extends' not in service_dict: return service_dict - extends_options = validate_extends_options(service_dict['name'], service_dict['extends']) + extends_options = self.validate_extends_options(service_dict['name'], service_dict['extends']) if self.working_dir is None: raise Exception("No working_dir passed to ServiceLoader()") @@ -194,25 +194,29 @@ class ServiceLoader(object): def signature(self, name): return (self.filename, name) + def validate_extends_options(self, service_name, extends_options): + error_prefix = "Invalid 'extends' configuration for %s:" % service_name -def validate_extends_options(service_name, extends_options): - error_prefix = "Invalid 'extends' configuration for %s:" % service_name + if not isinstance(extends_options, dict): + raise ConfigurationError("%s must be a dictionary" % error_prefix) - if not isinstance(extends_options, dict): - raise ConfigurationError("%s must be a dictionary" % error_prefix) - - if 'service' not in extends_options: - raise ConfigurationError( - "%s you need to specify a service, e.g. 'service: web'" % error_prefix - ) - - for k, _ in extends_options.items(): - if k not in ['file', 'service']: + if 'service' not in extends_options: raise ConfigurationError( - "%s unsupported configuration option '%s'" % (error_prefix, k) + "%s you need to specify a service, e.g. 'service: web'" % error_prefix ) - return extends_options + if 'file' not in extends_options and self.filename is None: + raise ConfigurationError( + "%s you need to specify a 'file', e.g. 'file: something.yml'" % error_prefix + ) + + for k, _ in extends_options.items(): + if k not in ['file', 'service']: + raise ConfigurationError( + "%s unsupported configuration option '%s'" % (error_prefix, k) + ) + + return extends_options def validate_extended_service_dict(service_dict, filename, service): diff --git a/tests/unit/config_test.py b/tests/unit/config_test.py index 48dd9afa8..4047a7253 100644 --- a/tests/unit/config_test.py +++ b/tests/unit/config_test.py @@ -459,6 +459,14 @@ class ExtendsTest(unittest.TestCase): self.assertRaisesRegexp(config.ConfigurationError, 'what', load_config) + def test_extends_validation_no_file_key_no_filename_set(self): + dictionary = {'extends': {'service': 'web'}} + + def load_config(): + return config.make_service_dict('myweb', dictionary, working_dir='tests/fixtures/extends') + + self.assertRaisesRegexp(config.ConfigurationError, 'file', load_config) + def test_extends_validation_valid_config(self): dictionary = {'extends': {'service': 'web', 'file': 'common.yml'}}