From 187ad4ce26401aaa10984c3c9a9782d6b2efdb87 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Thu, 3 Sep 2015 13:02:46 -0400 Subject: [PATCH] Refactor network_mode logic out of Service. Signed-off-by: Daniel Nephin --- compose/project.py | 11 +++-- compose/service.py | 88 +++++++++++++++++++++++++------------- tests/unit/project_test.py | 6 +-- tests/unit/service_test.py | 48 ++++++++++++++++++++- 4 files changed, 116 insertions(+), 37 deletions(-) diff --git a/compose/project.py b/compose/project.py index 54d6c4434..8db20e766 100644 --- a/compose/project.py +++ b/compose/project.py @@ -14,7 +14,10 @@ from .const import LABEL_PROJECT from .const import LABEL_SERVICE from .container import Container from .legacy import check_for_legacy_containers +from .service import ContainerNet +from .service import Net from .service import Service +from .service import ServiceNet from .utils import parallel_execute @@ -192,18 +195,18 @@ class Project(object): def get_net(self, service_dict): net = service_dict.pop('net', None) if not net: - return + return Net(None) net_name = get_service_name_from_net(net) if not net_name: - return net + return Net(net) try: - return self.get_service(net_name) + return ServiceNet(self.get_service(net_name)) except NoSuchService: pass try: - return Container.from_id(self.client, net_name) + return ContainerNet(Container.from_id(self.client, net_name)) except APIError: raise ConfigurationError( 'Service "%s" is trying to use the network of "%s", ' diff --git a/compose/service.py b/compose/service.py index f60d57bfd..bfc6f904e 100644 --- a/compose/service.py +++ b/compose/service.py @@ -105,7 +105,7 @@ class Service(object): self.project = project self.links = links or [] self.volumes_from = volumes_from or [] - self.net = net or None + self.net = net or Net(None) self.options = options def containers(self, stopped=False, one_off=False, filters={}): @@ -489,12 +489,12 @@ class Service(object): 'options': self.options, 'image_id': self.image()['Id'], 'links': [(service.name, alias) for service, alias in self.links], - 'net': self.get_net_name() or getattr(self.net, 'id', self.net), + 'net': self.net.id, 'volumes_from': self.get_volumes_from_names(), } def get_dependency_names(self): - net_name = self.get_net_name() + net_name = self.net.service_name return (self.get_linked_names() + self.get_volumes_from_names() + ([net_name] if net_name else [])) @@ -505,12 +505,6 @@ class Service(object): def get_volumes_from_names(self): return [s.name for s in self.volumes_from if isinstance(s, Service)] - def get_net_name(self): - if isinstance(self.net, Service): - return self.net.name - else: - return - def get_container_name(self, number, one_off=False): # TODO: Implement issue #652 here return build_container_name(self.project, self.name, number, one_off) @@ -562,25 +556,6 @@ class Service(object): return volumes_from - def _get_net(self): - if not self.net: - return None - - if isinstance(self.net, Service): - containers = self.net.containers() - if len(containers) > 0: - net = 'container:' + containers[0].id - else: - log.warning("Warning: Service %s is trying to use reuse the network stack " - "of another service that is not running." % (self.net.name)) - net = None - elif isinstance(self.net, Container): - net = 'container:' + self.net.id - else: - net = self.net - - return net - def _get_container_create_options( self, override_options, @@ -694,7 +669,7 @@ class Service(object): binds=options.get('binds'), volumes_from=self._get_volumes_from(), privileged=privileged, - network_mode=self._get_net(), + network_mode=self.net.mode, devices=devices, dns=dns, dns_search=dns_search, @@ -793,6 +768,61 @@ class Service(object): stream_output(output, sys.stdout) +class Net(object): + """A `standard` network mode (ex: host, bridge)""" + + service_name = None + + def __init__(self, net): + self.net = net + + @property + def id(self): + return self.net + + mode = id + + +class ContainerNet(object): + """A network mode that uses a containers network stack.""" + + service_name = None + + def __init__(self, container): + self.container = container + + @property + def id(self): + return self.container.id + + @property + def mode(self): + return 'container:' + self.container.id + + +class ServiceNet(object): + """A network mode that uses a service's network stack.""" + + def __init__(self, service): + self.service = service + + @property + def id(self): + return self.service.name + + service_name = id + + @property + def mode(self): + containers = self.service.containers() + if containers: + return 'container:' + containers[0].id + + log.warn("Warning: Service %s is trying to use reuse the network stack " + "of another service that is not running." % (self.id)) + return None + + # Names diff --git a/tests/unit/project_test.py b/tests/unit/project_test.py index 37ebe5148..ce74eb30b 100644 --- a/tests/unit/project_test.py +++ b/tests/unit/project_test.py @@ -221,7 +221,7 @@ class ProjectTest(unittest.TestCase): } ], self.mock_client) service = project.get_service('test') - self.assertEqual(service._get_net(), None) + self.assertEqual(service.net.id, None) self.assertNotIn('NetworkMode', service._get_container_host_config({})) def test_use_net_from_container(self): @@ -236,7 +236,7 @@ class ProjectTest(unittest.TestCase): } ], self.mock_client) service = project.get_service('test') - self.assertEqual(service._get_net(), 'container:' + container_id) + self.assertEqual(service.net.mode, 'container:' + container_id) def test_use_net_from_service(self): container_name = 'test_aaa_1' @@ -261,7 +261,7 @@ class ProjectTest(unittest.TestCase): ], self.mock_client) service = project.get_service('test') - self.assertEqual(service._get_net(), 'container:' + container_name) + self.assertEqual(service.net.mode, 'container:' + container_name) def test_container_without_name(self): self.mock_client.containers.return_value = [ diff --git a/tests/unit/service_test.py b/tests/unit/service_test.py index 3981cad20..de973339b 100644 --- a/tests/unit/service_test.py +++ b/tests/unit/service_test.py @@ -13,13 +13,16 @@ from compose.const import LABEL_SERVICE from compose.container import Container from compose.service import build_volume_binding from compose.service import ConfigError +from compose.service import ContainerNet from compose.service import get_container_data_volumes from compose.service import merge_volume_bindings from compose.service import NeedsBuildError +from compose.service import Net from compose.service import NoSuchImageError from compose.service import parse_repository_tag from compose.service import parse_volume_spec from compose.service import Service +from compose.service import ServiceNet class ServiceTest(unittest.TestCase): @@ -337,7 +340,7 @@ class ServiceTest(unittest.TestCase): 'foo', image='example.com/foo', client=self.mock_client, - net=Service('other'), + net=ServiceNet(Service('other')), links=[(Service('one'), 'one')], volumes_from=[Service('two')]) @@ -373,6 +376,49 @@ class ServiceTest(unittest.TestCase): self.assertEqual(config_dict, expected) +class NetTestCase(unittest.TestCase): + + def test_net(self): + net = Net('host') + self.assertEqual(net.id, 'host') + self.assertEqual(net.mode, 'host') + self.assertEqual(net.service_name, None) + + def test_net_container(self): + container_id = 'abcd' + net = ContainerNet(Container(None, {'Id': container_id})) + self.assertEqual(net.id, container_id) + self.assertEqual(net.mode, 'container:' + container_id) + self.assertEqual(net.service_name, None) + + def test_net_service(self): + container_id = 'bbbb' + service_name = 'web' + mock_client = mock.create_autospec(docker.Client) + mock_client.containers.return_value = [ + {'Id': container_id, 'Name': container_id, 'Image': 'abcd'}, + ] + + service = Service(name=service_name, client=mock_client) + net = ServiceNet(service) + + self.assertEqual(net.id, service_name) + self.assertEqual(net.mode, 'container:' + container_id) + self.assertEqual(net.service_name, service_name) + + def test_net_service_no_containers(self): + service_name = 'web' + mock_client = mock.create_autospec(docker.Client) + mock_client.containers.return_value = [] + + service = Service(name=service_name, client=mock_client) + net = ServiceNet(service) + + self.assertEqual(net.id, service_name) + self.assertEqual(net.mode, None) + self.assertEqual(net.service_name, service_name) + + def mock_get_image(images): if images: return images[0]