From 0c87e0b18fa354ddfe1543f5a2391db375fe80be Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Fri, 29 Jan 2016 14:41:18 -0500 Subject: [PATCH] Refactor project network initlization. Signed-off-by: Daniel Nephin --- compose/network.py | 64 ++++++++++++++++++++++++++ compose/project.py | 93 ++++++++------------------------------ tests/unit/project_test.py | 8 ++-- 3 files changed, 88 insertions(+), 77 deletions(-) diff --git a/compose/network.py b/compose/network.py index 4f4f55228..4f4e06b22 100644 --- a/compose/network.py +++ b/compose/network.py @@ -104,3 +104,67 @@ def create_ipam_config_from_dict(ipam_dict): for config in ipam_dict.get('config', []) ], ) + + +def build_networks(name, config_data, client): + network_config = config_data.networks or {} + networks = { + network_name: Network( + client=client, project=name, name=network_name, + driver=data.get('driver'), + driver_opts=data.get('driver_opts'), + ipam=data.get('ipam'), + external_name=data.get('external_name'), + ) + for network_name, data in network_config.items() + } + + if 'default' not in networks: + networks['default'] = Network(client, name, 'default') + + return networks + + +class ProjectNetworks(object): + + def __init__(self, networks, use_networking): + self.networks = networks or {} + self.use_networking = use_networking + + @classmethod + def from_services(cls, services, networks, use_networking): + networks = { + network: networks[network] + for service in services + for network in service.get('networks', ['default']) + } + return cls(networks, use_networking) + + def remove(self): + if not self.use_networking: + return + for network in self.networks.values(): + network.remove() + + def initialize(self): + if not self.use_networking: + return + + for network in self.networks.values(): + network.ensure() + + +def get_networks(service_dict, network_definitions): + if 'network_mode' in service_dict: + return [] + + networks = [] + for name in service_dict.pop('networks', ['default']): + network = network_definitions.get(name) + if network: + networks.append(network.full_name) + else: + raise ConfigurationError( + 'Service "{}" uses an undefined network "{}"' + .format(service_dict['name'], name)) + return networks diff --git a/compose/project.py b/compose/project.py index 6411f7cc3..2e9cfb8f3 100644 --- a/compose/project.py +++ b/compose/project.py @@ -19,7 +19,9 @@ from .const import LABEL_ONE_OFF from .const import LABEL_PROJECT from .const import LABEL_SERVICE from .container import Container -from .network import Network +from .network import build_networks +from .network import get_networks +from .network import ProjectNetworks from .service import ContainerNetworkMode from .service import ConvergenceStrategy from .service import NetworkMode @@ -36,15 +38,12 @@ class Project(object): """ A collection of services. """ - def __init__(self, name, services, client, networks=None, volumes=None, - use_networking=False, network_driver=None): + def __init__(self, name, services, client, networks=None, volumes=None): self.name = name self.services = services self.client = client - self.use_networking = use_networking - self.network_driver = network_driver - self.networks = networks or [] self.volumes = volumes or {} + self.networks = networks or ProjectNetworks({}, False) def labels(self, one_off=False): return [ @@ -58,23 +57,12 @@ class Project(object): Construct a Project from a config.Config object. """ use_networking = (config_data.version and config_data.version != V1) - project = cls(name, [], client, use_networking=use_networking) - - network_config = config_data.networks or {} - custom_networks = [ - Network( - client=client, project=name, name=network_name, - driver=data.get('driver'), - driver_opts=data.get('driver_opts'), - ipam=data.get('ipam'), - external_name=data.get('external_name'), - ) - for network_name, data in network_config.items() - ] - - all_networks = custom_networks[:] - if 'default' not in network_config: - all_networks.append(project.default_network) + networks = build_networks(name, config_data, client) + project_networks = ProjectNetworks.from_services( + config_data.services, + networks, + use_networking) + project = cls(name, [], client, project_networks) if config_data.volumes: for vol_name, data in config_data.volumes.items(): @@ -86,13 +74,15 @@ class Project(object): ) for service_dict in config_data.services: + service_dict = dict(service_dict) if use_networking: - networks = get_networks(service_dict, all_networks) + service_networks = get_networks(service_dict, networks) else: - networks = [] + service_networks = [] + service_dict.pop('networks', None) links = project.get_links(service_dict) - network_mode = project.get_network_mode(service_dict, networks) + network_mode = project.get_network_mode(service_dict, service_networks) volumes_from = get_volumes_from(project, service_dict) if config_data.version != V1: @@ -109,17 +99,13 @@ class Project(object): client=client, project=name, use_networking=use_networking, - networks=networks, + networks=service_networks, links=links, network_mode=network_mode, volumes_from=volumes_from, **service_dict) ) - project.networks += custom_networks - if 'default' not in network_config and project.uses_default_network(): - project.networks.append(project.default_network) - return project @property @@ -201,7 +187,7 @@ class Project(object): def get_network_mode(self, service_dict, networks): network_mode = service_dict.pop('network_mode', None) if not network_mode: - if self.use_networking: + if self.networks.use_networking: return NetworkMode(networks[0]) if networks else NetworkMode('none') return NetworkMode(None) @@ -285,7 +271,7 @@ class Project(object): def down(self, remove_image_type, include_volumes): self.stop() self.remove_stopped(v=include_volumes) - self.remove_networks() + self.networks.remove() if include_volumes: self.remove_volumes() @@ -296,33 +282,10 @@ class Project(object): for service in self.get_services(): service.remove_image(remove_image_type) - def remove_networks(self): - if not self.use_networking: - return - for network in self.networks: - network.remove() - def remove_volumes(self): for volume in self.volumes.values(): volume.remove() - def initialize_networks(self): - if not self.use_networking: - return - - for network in self.networks: - network.ensure() - - def uses_default_network(self): - return any( - self.default_network.full_name in service.networks - for service in self.services - ) - - @property - def default_network(self): - return Network(client=self.client, project=self.name, name='default') - def restart(self, service_names=None, **options): containers = self.containers(service_names, stopped=True) parallel.parallel_restart(containers, options) @@ -392,7 +355,7 @@ class Project(object): plans = self._get_convergence_plans(services, strategy) - self.initialize_networks() + self.networks.initialize() self.initialize_volumes() return [ @@ -465,22 +428,6 @@ class Project(object): return acc + dep_services -def get_networks(service_dict, network_definitions): - if 'network_mode' in service_dict: - return [] - - networks = [] - for name in service_dict.pop('networks', ['default']): - matches = [n for n in network_definitions if n.name == name] - if matches: - networks.append(matches[0].full_name) - else: - raise ConfigurationError( - 'Service "{}" uses an undefined network "{}"' - .format(service_dict['name'], name)) - return networks - - def get_volumes_from(project, service_dict): volumes_from = service_dict.pop('volumes_from', None) if not volumes_from: diff --git a/tests/unit/project_test.py b/tests/unit/project_test.py index 21c6be475..bec238de6 100644 --- a/tests/unit/project_test.py +++ b/tests/unit/project_test.py @@ -45,7 +45,7 @@ class ProjectTest(unittest.TestCase): self.assertEqual(project.get_service('web').options['image'], 'busybox:latest') self.assertEqual(project.get_service('db').name, 'db') self.assertEqual(project.get_service('db').options['image'], 'busybox:latest') - self.assertFalse(project.use_networking) + self.assertFalse(project.networks.use_networking) def test_from_config_v2(self): config = Config( @@ -65,7 +65,7 @@ class ProjectTest(unittest.TestCase): ) project = Project.from_config('composetest', config, None) self.assertEqual(len(project.services), 2) - self.assertTrue(project.use_networking) + self.assertTrue(project.networks.use_networking) def test_get_service(self): web = Service( @@ -426,7 +426,7 @@ class ProjectTest(unittest.TestCase): ), ) - assert project.uses_default_network() + assert 'default' in project.networks.networks def test_uses_default_network_false(self): project = Project.from_config( @@ -446,7 +446,7 @@ class ProjectTest(unittest.TestCase): ), ) - assert not project.uses_default_network() + assert 'default' not in project.networks.networks def test_container_without_name(self): self.mock_client.containers.return_value = [