Refactor project network initlization.

Signed-off-by: Daniel Nephin <dnephin@docker.com>
This commit is contained in:
Daniel Nephin 2016-01-29 14:41:18 -05:00 committed by Aanand Prasad
parent 8024f2f09e
commit 0c87e0b18f
3 changed files with 88 additions and 77 deletions

View File

@ -104,3 +104,67 @@ def create_ipam_config_from_dict(ipam_dict):
for config in ipam_dict.get('config', []) for config in ipam_dict.get('config', [])
], ],
) )
def build_networks(name, config_data, client):
network_config = config_data.networks or {}
networks = {
network_name: Network(
client=client, project=name, name=network_name,
driver=data.get('driver'),
driver_opts=data.get('driver_opts'),
ipam=data.get('ipam'),
external_name=data.get('external_name'),
)
for network_name, data in network_config.items()
}
if 'default' not in networks:
networks['default'] = Network(client, name, 'default')
return networks
class ProjectNetworks(object):
def __init__(self, networks, use_networking):
self.networks = networks or {}
self.use_networking = use_networking
@classmethod
def from_services(cls, services, networks, use_networking):
networks = {
network: networks[network]
for service in services
for network in service.get('networks', ['default'])
}
return cls(networks, use_networking)
def remove(self):
if not self.use_networking:
return
for network in self.networks.values():
network.remove()
def initialize(self):
if not self.use_networking:
return
for network in self.networks.values():
network.ensure()
def get_networks(service_dict, network_definitions):
if 'network_mode' in service_dict:
return []
networks = []
for name in service_dict.pop('networks', ['default']):
network = network_definitions.get(name)
if network:
networks.append(network.full_name)
else:
raise ConfigurationError(
'Service "{}" uses an undefined network "{}"'
.format(service_dict['name'], name))
return networks

View File

@ -19,7 +19,9 @@ from .const import LABEL_ONE_OFF
from .const import LABEL_PROJECT from .const import LABEL_PROJECT
from .const import LABEL_SERVICE from .const import LABEL_SERVICE
from .container import Container from .container import Container
from .network import Network from .network import build_networks
from .network import get_networks
from .network import ProjectNetworks
from .service import ContainerNetworkMode from .service import ContainerNetworkMode
from .service import ConvergenceStrategy from .service import ConvergenceStrategy
from .service import NetworkMode from .service import NetworkMode
@ -36,15 +38,12 @@ class Project(object):
""" """
A collection of services. A collection of services.
""" """
def __init__(self, name, services, client, networks=None, volumes=None, def __init__(self, name, services, client, networks=None, volumes=None):
use_networking=False, network_driver=None):
self.name = name self.name = name
self.services = services self.services = services
self.client = client self.client = client
self.use_networking = use_networking
self.network_driver = network_driver
self.networks = networks or []
self.volumes = volumes or {} self.volumes = volumes or {}
self.networks = networks or ProjectNetworks({}, False)
def labels(self, one_off=False): def labels(self, one_off=False):
return [ return [
@ -58,23 +57,12 @@ class Project(object):
Construct a Project from a config.Config object. Construct a Project from a config.Config object.
""" """
use_networking = (config_data.version and config_data.version != V1) use_networking = (config_data.version and config_data.version != V1)
project = cls(name, [], client, use_networking=use_networking) networks = build_networks(name, config_data, client)
project_networks = ProjectNetworks.from_services(
network_config = config_data.networks or {} config_data.services,
custom_networks = [ networks,
Network( use_networking)
client=client, project=name, name=network_name, project = cls(name, [], client, project_networks)
driver=data.get('driver'),
driver_opts=data.get('driver_opts'),
ipam=data.get('ipam'),
external_name=data.get('external_name'),
)
for network_name, data in network_config.items()
]
all_networks = custom_networks[:]
if 'default' not in network_config:
all_networks.append(project.default_network)
if config_data.volumes: if config_data.volumes:
for vol_name, data in config_data.volumes.items(): for vol_name, data in config_data.volumes.items():
@ -86,13 +74,15 @@ class Project(object):
) )
for service_dict in config_data.services: for service_dict in config_data.services:
service_dict = dict(service_dict)
if use_networking: if use_networking:
networks = get_networks(service_dict, all_networks) service_networks = get_networks(service_dict, networks)
else: else:
networks = [] service_networks = []
service_dict.pop('networks', None)
links = project.get_links(service_dict) links = project.get_links(service_dict)
network_mode = project.get_network_mode(service_dict, networks) network_mode = project.get_network_mode(service_dict, service_networks)
volumes_from = get_volumes_from(project, service_dict) volumes_from = get_volumes_from(project, service_dict)
if config_data.version != V1: if config_data.version != V1:
@ -109,17 +99,13 @@ class Project(object):
client=client, client=client,
project=name, project=name,
use_networking=use_networking, use_networking=use_networking,
networks=networks, networks=service_networks,
links=links, links=links,
network_mode=network_mode, network_mode=network_mode,
volumes_from=volumes_from, volumes_from=volumes_from,
**service_dict) **service_dict)
) )
project.networks += custom_networks
if 'default' not in network_config and project.uses_default_network():
project.networks.append(project.default_network)
return project return project
@property @property
@ -201,7 +187,7 @@ class Project(object):
def get_network_mode(self, service_dict, networks): def get_network_mode(self, service_dict, networks):
network_mode = service_dict.pop('network_mode', None) network_mode = service_dict.pop('network_mode', None)
if not network_mode: if not network_mode:
if self.use_networking: if self.networks.use_networking:
return NetworkMode(networks[0]) if networks else NetworkMode('none') return NetworkMode(networks[0]) if networks else NetworkMode('none')
return NetworkMode(None) return NetworkMode(None)
@ -285,7 +271,7 @@ class Project(object):
def down(self, remove_image_type, include_volumes): def down(self, remove_image_type, include_volumes):
self.stop() self.stop()
self.remove_stopped(v=include_volumes) self.remove_stopped(v=include_volumes)
self.remove_networks() self.networks.remove()
if include_volumes: if include_volumes:
self.remove_volumes() self.remove_volumes()
@ -296,33 +282,10 @@ class Project(object):
for service in self.get_services(): for service in self.get_services():
service.remove_image(remove_image_type) service.remove_image(remove_image_type)
def remove_networks(self):
if not self.use_networking:
return
for network in self.networks:
network.remove()
def remove_volumes(self): def remove_volumes(self):
for volume in self.volumes.values(): for volume in self.volumes.values():
volume.remove() volume.remove()
def initialize_networks(self):
if not self.use_networking:
return
for network in self.networks:
network.ensure()
def uses_default_network(self):
return any(
self.default_network.full_name in service.networks
for service in self.services
)
@property
def default_network(self):
return Network(client=self.client, project=self.name, name='default')
def restart(self, service_names=None, **options): def restart(self, service_names=None, **options):
containers = self.containers(service_names, stopped=True) containers = self.containers(service_names, stopped=True)
parallel.parallel_restart(containers, options) parallel.parallel_restart(containers, options)
@ -392,7 +355,7 @@ class Project(object):
plans = self._get_convergence_plans(services, strategy) plans = self._get_convergence_plans(services, strategy)
self.initialize_networks() self.networks.initialize()
self.initialize_volumes() self.initialize_volumes()
return [ return [
@ -465,22 +428,6 @@ class Project(object):
return acc + dep_services return acc + dep_services
def get_networks(service_dict, network_definitions):
if 'network_mode' in service_dict:
return []
networks = []
for name in service_dict.pop('networks', ['default']):
matches = [n for n in network_definitions if n.name == name]
if matches:
networks.append(matches[0].full_name)
else:
raise ConfigurationError(
'Service "{}" uses an undefined network "{}"'
.format(service_dict['name'], name))
return networks
def get_volumes_from(project, service_dict): def get_volumes_from(project, service_dict):
volumes_from = service_dict.pop('volumes_from', None) volumes_from = service_dict.pop('volumes_from', None)
if not volumes_from: if not volumes_from:

View File

@ -45,7 +45,7 @@ class ProjectTest(unittest.TestCase):
self.assertEqual(project.get_service('web').options['image'], 'busybox:latest') self.assertEqual(project.get_service('web').options['image'], 'busybox:latest')
self.assertEqual(project.get_service('db').name, 'db') self.assertEqual(project.get_service('db').name, 'db')
self.assertEqual(project.get_service('db').options['image'], 'busybox:latest') self.assertEqual(project.get_service('db').options['image'], 'busybox:latest')
self.assertFalse(project.use_networking) self.assertFalse(project.networks.use_networking)
def test_from_config_v2(self): def test_from_config_v2(self):
config = Config( config = Config(
@ -65,7 +65,7 @@ class ProjectTest(unittest.TestCase):
) )
project = Project.from_config('composetest', config, None) project = Project.from_config('composetest', config, None)
self.assertEqual(len(project.services), 2) self.assertEqual(len(project.services), 2)
self.assertTrue(project.use_networking) self.assertTrue(project.networks.use_networking)
def test_get_service(self): def test_get_service(self):
web = Service( web = Service(
@ -426,7 +426,7 @@ class ProjectTest(unittest.TestCase):
), ),
) )
assert project.uses_default_network() assert 'default' in project.networks.networks
def test_uses_default_network_false(self): def test_uses_default_network_false(self):
project = Project.from_config( project = Project.from_config(
@ -446,7 +446,7 @@ class ProjectTest(unittest.TestCase):
), ),
) )
assert not project.uses_default_network() assert 'default' not in project.networks.networks
def test_container_without_name(self): def test_container_without_name(self):
self.mock_client.containers.return_value = [ self.mock_client.containers.return_value = [