Extract link names into a function.

Signed-off-by: Daniel Nephin <dnephin@docker.com>
This commit is contained in:
Daniel Nephin 2015-09-03 13:13:22 -04:00
parent 805f6a7683
commit db31adc208
4 changed files with 20 additions and 18 deletions

View File

@ -304,7 +304,7 @@ class TopLevelCommand(Command):
log.warn(INSECURE_SSL_WARNING) log.warn(INSECURE_SSL_WARNING)
if not options['--no-deps']: if not options['--no-deps']:
deps = service.get_linked_names() deps = service.get_linked_service_names()
if len(deps) > 0: if len(deps) > 0:
project.up( project.up(

View File

@ -477,19 +477,22 @@ class Service(object):
return { return {
'options': self.options, 'options': self.options,
'image_id': self.image()['Id'], 'image_id': self.image()['Id'],
'links': [(service.name, alias) for service, alias in self.links], 'links': self.get_link_names(),
'net': self.net.id, 'net': self.net.id,
'volumes_from': self.get_volumes_from_names(), 'volumes_from': self.get_volumes_from_names(),
} }
def get_dependency_names(self): def get_dependency_names(self):
net_name = self.net.service_name net_name = self.net.service_name
return (self.get_linked_names() + return (self.get_linked_service_names() +
self.get_volumes_from_names() + self.get_volumes_from_names() +
([net_name] if net_name else [])) ([net_name] if net_name else []))
def get_linked_names(self): def get_linked_service_names(self):
return [s.name for (s, _) in self.links] return [service.name for (service, _) in self.links]
def get_link_names(self):
return [(service.name, alias) for service, alias in self.links]
def get_volumes_from_names(self): def get_volumes_from_names(self):
return [s.name for s in self.volumes_from if isinstance(s, Service)] return [s.name for s in self.volumes_from if isinstance(s, Service)]
@ -776,7 +779,7 @@ class Net(object):
class ContainerNet(object): class ContainerNet(object):
"""A network mode that uses a containers network stack.""" """A network mode that uses a container's network stack."""
service_name = None service_name = None

View File

@ -112,7 +112,7 @@ class ProjectTest(DockerClientTestCase):
web = project.get_service('web') web = project.get_service('web')
net = project.get_service('net') net = project.get_service('net')
self.assertEqual(web._get_net(), 'container:' + net.containers()[0].id) self.assertEqual(web.net.mode, 'container:' + net.containers()[0].id)
def test_net_from_container(self): def test_net_from_container(self):
net_container = Container.create( net_container = Container.create(
@ -138,7 +138,7 @@ class ProjectTest(DockerClientTestCase):
project.up() project.up()
web = project.get_service('web') web = project.get_service('web')
self.assertEqual(web._get_net(), 'container:' + net_container.id) self.assertEqual(web.net.mode, 'container:' + net_container.id)
def test_start_stop_kill_remove(self): def test_start_stop_kill_remove(self):
web = self.create_service('web') web = self.create_service('web')

View File

@ -9,6 +9,7 @@ import tempfile
import shutil import shutil
from six import StringIO, text_type from six import StringIO, text_type
from .testcases import DockerClientTestCase
from compose import __version__ from compose import __version__
from compose.const import ( from compose.const import (
LABEL_CONTAINER_NUMBER, LABEL_CONTAINER_NUMBER,
@ -17,14 +18,12 @@ from compose.const import (
LABEL_SERVICE, LABEL_SERVICE,
LABEL_VERSION, LABEL_VERSION,
) )
from compose.service import (
ConfigError,
ConvergencePlan,
Service,
build_extra_hosts,
)
from compose.container import Container from compose.container import Container
from .testcases import DockerClientTestCase from compose.service import build_extra_hosts
from compose.service import ConfigError
from compose.service import ConvergencePlan
from compose.service import Net
from compose.service import Service
def create_and_start_container(service, **override_options): def create_and_start_container(service, **override_options):
@ -743,17 +742,17 @@ class ServiceTest(DockerClientTestCase):
self.assertEqual(list(container.inspect()['HostConfig']['PortBindings'].keys()), ['8000/tcp']) self.assertEqual(list(container.inspect()['HostConfig']['PortBindings'].keys()), ['8000/tcp'])
def test_network_mode_none(self): def test_network_mode_none(self):
service = self.create_service('web', net='none') service = self.create_service('web', net=Net('none'))
container = create_and_start_container(service) container = create_and_start_container(service)
self.assertEqual(container.get('HostConfig.NetworkMode'), 'none') self.assertEqual(container.get('HostConfig.NetworkMode'), 'none')
def test_network_mode_bridged(self): def test_network_mode_bridged(self):
service = self.create_service('web', net='bridge') service = self.create_service('web', net=Net('bridge'))
container = create_and_start_container(service) container = create_and_start_container(service)
self.assertEqual(container.get('HostConfig.NetworkMode'), 'bridge') self.assertEqual(container.get('HostConfig.NetworkMode'), 'bridge')
def test_network_mode_host(self): def test_network_mode_host(self):
service = self.create_service('web', net='host') service = self.create_service('web', net=Net('host'))
container = create_and_start_container(service) container = create_and_start_container(service)
self.assertEqual(container.get('HostConfig.NetworkMode'), 'host') self.assertEqual(container.get('HostConfig.NetworkMode'), 'host')