Handle volume driver change error in config.

Assume version=1 if file is empty in get_config_version
Empty files are invalid anyway, so this simplifies the algorithm
somewhat.
https://github.com/docker/compose/pull/2421#discussion_r47223144

Don't leak version considerations in interpolation/service validation

Signed-off-by: Joffrey F <joffrey@docker.com>
This commit is contained in:
Joffrey F 2015-12-11 17:21:04 -08:00
parent f3a9533dc0
commit a7689f3da8
6 changed files with 94 additions and 21 deletions

View File

@ -118,6 +118,9 @@ class ConfigFile(namedtuple('_ConfigFile', 'filename config')):
def from_filename(cls, filename): def from_filename(cls, filename):
return cls(filename, load_yaml(filename)) return cls(filename, load_yaml(filename))
def get_service_dicts(self, version):
return self.config if version == 1 else self.config.get('services', {})
class Config(namedtuple('_Config', 'version services volumes')): class Config(namedtuple('_Config', 'version services volumes')):
""" """
@ -164,9 +167,11 @@ def find(base_dir, filenames):
def get_config_version(config_details): def get_config_version(config_details):
def get_version(config): def get_version(config):
if config.config is None: if config.config is None:
return None return 1
version = config.config.get('version', 1) version = config.config.get('version', 1)
if isinstance(version, dict): if isinstance(version, dict):
# in that case 'version' is probably a service name, so assume
# this is a legacy (version=1) file
version = 1 version = 1
return version return version
@ -176,9 +181,6 @@ def get_config_version(config_details):
for next_file in config_details.config_files[1:]: for next_file in config_details.config_files[1:]:
validate_top_level_object(next_file) validate_top_level_object(next_file)
next_file_version = get_version(next_file) next_file_version = get_version(next_file)
if version is None:
version = next_file_version
continue
if version != next_file_version and next_file_version is not None: if version != next_file_version and next_file_version is not None:
raise ConfigurationError( raise ConfigurationError(
@ -316,8 +318,16 @@ def load_services(working_dir, config_files, version):
def process_config_file(config_file, version, service_name=None): def process_config_file(config_file, version, service_name=None):
validate_top_level_service_objects(config_file, version) service_dicts = config_file.get_service_dicts(version)
processed_config = interpolate_environment_variables(config_file.config, version) validate_top_level_service_objects(
config_file.filename, service_dicts
)
interpolated_config = interpolate_environment_variables(service_dicts)
if version == 2:
processed_config = dict(config_file.config)
processed_config.update({'services': interpolated_config})
if version == 1:
processed_config = interpolated_config
validate_against_fields_schema( validate_against_fields_schema(
processed_config, config_file.filename, version processed_config, config_file.filename, version
) )

View File

@ -8,19 +8,13 @@ from .errors import ConfigurationError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def interpolate_environment_variables(config, version): def interpolate_environment_variables(service_dicts):
mapping = BlankDefaultDict(os.environ) mapping = BlankDefaultDict(os.environ)
service_dicts = config if version == 1 else config.get('services', {})
interpolated = dict( return dict(
(service_name, process_service(service_name, service_dict, mapping)) (service_name, process_service(service_name, service_dict, mapping))
for (service_name, service_dict) in service_dicts.items() for (service_name, service_dict) in service_dicts.items()
) )
if version == 1:
return interpolated
result = dict(config)
result.update({'services': interpolated})
return result
def process_service(service_name, service_dict, mapping): def process_service(service_name, service_dict, mapping):

View File

@ -74,19 +74,18 @@ def format_boolean_in_environment(instance):
return True return True
def validate_top_level_service_objects(config_file, version): def validate_top_level_service_objects(filename, service_dicts):
"""Perform some high level validation of the service name and value. """Perform some high level validation of the service name and value.
This validation must happen before interpolation, which must happen This validation must happen before interpolation, which must happen
before the rest of validation, which is why it's separate from the before the rest of validation, which is why it's separate from the
rest of the service validation. rest of the service validation.
""" """
service_dicts = config_file.config if version == 1 else config_file.config.get('services', {})
for service_name, service_dict in service_dicts.items(): for service_name, service_dict in service_dicts.items():
if not isinstance(service_name, six.string_types): if not isinstance(service_name, six.string_types):
raise ConfigurationError( raise ConfigurationError(
"In file '{}' service name: {} needs to be a string, eg '{}'".format( "In file '{}' service name: {} needs to be a string, eg '{}'".format(
config_file.filename, filename,
service_name, service_name,
service_name)) service_name))
@ -95,8 +94,9 @@ def validate_top_level_service_objects(config_file, version):
"In file '{}' service '{}' doesn\'t have any configuration options. " "In file '{}' service '{}' doesn\'t have any configuration options. "
"All top level keys in your docker-compose.yml must map " "All top level keys in your docker-compose.yml must map "
"to a dictionary of configuration options.".format( "to a dictionary of configuration options.".format(
config_file.filename, filename, service_name
service_name)) )
)
def validate_top_level_object(config_file): def validate_top_level_object(config_file):

View File

@ -236,6 +236,18 @@ class Project(object):
raise ConfigurationError( raise ConfigurationError(
'Volume %s specifies nonexistent driver %s' % (volume.name, volume.driver) 'Volume %s specifies nonexistent driver %s' % (volume.name, volume.driver)
) )
except APIError as e:
if 'Choose a different volume name' in str(e):
raise ConfigurationError(
'Configuration for volume {0} specifies driver {1}, but '
'a volume with the same name uses a different driver '
'({3}). If you wish to use the new configuration, please '
'remove the existing volume "{2}" first:\n'
'$ docker volume rm {2}'.format(
volume.name, volume.driver, volume.full_name,
volume.inspect()['Driver']
)
)
def restart(self, service_names=None, **options): def restart(self, service_names=None, **options):
containers = self.containers(service_names, stopped=True) containers = self.containers(service_names, stopped=True)

View File

@ -579,11 +579,11 @@ class ProjectTest(DockerClientTestCase):
vol_name = '{0:x}'.format(random.getrandbits(32)) vol_name = '{0:x}'.format(random.getrandbits(32))
config_data = config.Config( config_data = config.Config(
2, [{ version=2, services=[{
'name': 'web', 'name': 'web',
'image': 'busybox:latest', 'image': 'busybox:latest',
'command': 'top' 'command': 'top'
}], {vol_name: {'driver': 'foobar'}} }], volumes={vol_name: {'driver': 'foobar'}}
) )
project = Project.from_config( project = Project.from_config(
@ -592,3 +592,37 @@ class ProjectTest(DockerClientTestCase):
) )
with self.assertRaises(config.ConfigurationError): with self.assertRaises(config.ConfigurationError):
project.initialize_volumes() project.initialize_volumes()
def test_project_up_updated_driver(self):
vol_name = '{0:x}'.format(random.getrandbits(32))
full_vol_name = 'composetest_{0}'.format(vol_name)
config_data = config.Config(
version=2, services=[{
'name': 'web',
'image': 'busybox:latest',
'command': 'top'
}], volumes={vol_name: {'driver': 'local'}}
)
project = Project.from_config(
name='composetest',
config_data=config_data, client=self.client
)
project.initialize_volumes()
volume_data = self.client.inspect_volume(full_vol_name)
self.assertEqual(volume_data['Name'], full_vol_name)
self.assertEqual(volume_data['Driver'], 'local')
config_data = config_data._replace(
volumes={vol_name: {'driver': 'smb'}}
)
project = Project.from_config(
name='composetest',
config_data=config_data, client=self.client
)
with self.assertRaises(config.ConfigurationError) as e:
project.initialize_volumes()
assert 'Configuration for volume {0} specifies driver smb'.format(
vol_name
) in str(e.exception)

View File

@ -286,6 +286,18 @@ class ConfigTest(unittest.TestCase):
error_msg = "Top level object in 'override.yml' needs to be an object" error_msg = "Top level object in 'override.yml' needs to be an object"
assert error_msg in exc.exconly() assert error_msg in exc.exconly()
def test_load_with_multiple_files_and_empty_override_v2(self):
base_file = config.ConfigFile(
'base.yml',
{'version': 2, 'services': {'web': {'image': 'example/web'}}})
override_file = config.ConfigFile('override.yml', None)
details = config.ConfigDetails('.', [base_file, override_file])
with pytest.raises(ConfigurationError) as exc:
config.load(details)
error_msg = "Top level object in 'override.yml' needs to be an object"
assert error_msg in exc.exconly()
def test_load_with_multiple_files_and_empty_base(self): def test_load_with_multiple_files_and_empty_base(self):
base_file = config.ConfigFile('base.yml', None) base_file = config.ConfigFile('base.yml', None)
override_file = config.ConfigFile( override_file = config.ConfigFile(
@ -297,6 +309,17 @@ class ConfigTest(unittest.TestCase):
config.load(details) config.load(details)
assert "Top level object in 'base.yml' needs to be an object" in exc.exconly() assert "Top level object in 'base.yml' needs to be an object" in exc.exconly()
def test_load_with_multiple_files_and_empty_base_v2(self):
base_file = config.ConfigFile('base.yml', None)
override_file = config.ConfigFile(
'override.tml',
{'version': 2, 'services': {'web': {'image': 'example/web'}}}
)
details = config.ConfigDetails('.', [base_file, override_file])
with pytest.raises(ConfigurationError) as exc:
config.load(details)
assert "Top level object in 'base.yml' needs to be an object" in exc.exconly()
def test_load_with_multiple_files_and_extends_in_override_file(self): def test_load_with_multiple_files_and_extends_in_override_file(self):
base_file = config.ConfigFile( base_file = config.ConfigFile(
'base.yaml', 'base.yaml',