Merge pull request #1076 from gilclark/master

Make volumes_from and net containers first class dependencies
This commit is contained in:
Aanand Prasad 2015-03-13 11:46:13 +00:00
commit 25c70c2af4
6 changed files with 429 additions and 41 deletions

View File

@ -293,7 +293,7 @@ class TopLevelCommand(Command):
if len(deps) > 0: if len(deps) > 0:
project.up( project.up(
service_names=deps, service_names=deps,
start_links=True, start_deps=True,
recreate=False, recreate=False,
insecure_registry=insecure_registry, insecure_registry=insecure_registry,
detach=options['-d'] detach=options['-d']
@ -435,13 +435,13 @@ class TopLevelCommand(Command):
monochrome = options['--no-color'] monochrome = options['--no-color']
start_links = not options['--no-deps'] start_deps = not options['--no-deps']
recreate = not options['--no-recreate'] recreate = not options['--no-recreate']
service_names = options['SERVICE'] service_names = options['SERVICE']
project.up( project.up(
service_names=service_names, service_names=service_names,
start_links=start_links, start_deps=start_deps,
recreate=recreate, recreate=recreate,
insecure_registry=insecure_registry, insecure_registry=insecure_registry,
detach=options['-d'], detach=options['-d'],

View File

@ -10,6 +10,17 @@ from docker.errors import APIError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def get_service_name_from_net(net_config):
if not net_config:
return
if not net_config.startswith('container:'):
return
_, net_name = net_config.split(':', 1)
return net_name
def sort_service_dicts(services): def sort_service_dicts(services):
# Topological sort (Cormen/Tarjan algorithm). # Topological sort (Cormen/Tarjan algorithm).
unmarked = services[:] unmarked = services[:]
@ -19,6 +30,15 @@ def sort_service_dicts(services):
def get_service_names(links): def get_service_names(links):
return [link.split(':')[0] for link in links] return [link.split(':')[0] for link in links]
def get_service_dependents(service_dict, services):
name = service_dict['name']
return [
service for service in services
if (name in get_service_names(service.get('links', [])) or
name in service.get('volumes_from', []) or
name == get_service_name_from_net(service.get('net')))
]
def visit(n): def visit(n):
if n['name'] in temporary_marked: if n['name'] in temporary_marked:
if n['name'] in get_service_names(n.get('links', [])): if n['name'] in get_service_names(n.get('links', [])):
@ -29,8 +49,7 @@ def sort_service_dicts(services):
raise DependencyError('Circular import between %s' % ' and '.join(temporary_marked)) raise DependencyError('Circular import between %s' % ' and '.join(temporary_marked))
if n in unmarked: if n in unmarked:
temporary_marked.add(n['name']) temporary_marked.add(n['name'])
dependents = [m for m in services if (n['name'] in get_service_names(m.get('links', []))) or (n['name'] in m.get('volumes_from', []))] for m in get_service_dependents(n, services):
for m in dependents:
visit(m) visit(m)
temporary_marked.remove(n['name']) temporary_marked.remove(n['name'])
unmarked.remove(n) unmarked.remove(n)
@ -60,8 +79,10 @@ class Project(object):
for service_dict in sort_service_dicts(service_dicts): for service_dict in sort_service_dicts(service_dicts):
links = project.get_links(service_dict) links = project.get_links(service_dict)
volumes_from = project.get_volumes_from(service_dict) volumes_from = project.get_volumes_from(service_dict)
net = project.get_net(service_dict)
project.services.append(Service(client=client, project=name, links=links, volumes_from=volumes_from, **service_dict)) project.services.append(Service(client=client, project=name, links=links, net=net,
volumes_from=volumes_from, **service_dict))
return project return project
@classmethod @classmethod
@ -85,31 +106,31 @@ class Project(object):
raise NoSuchService(name) raise NoSuchService(name)
def get_services(self, service_names=None, include_links=False): def get_services(self, service_names=None, include_deps=False):
""" """
Returns a list of this project's services filtered Returns a list of this project's services filtered
by the provided list of names, or all services if service_names is None by the provided list of names, or all services if service_names is None
or []. or [].
If include_links is specified, returns a list including the links for If include_deps is specified, returns a list including the dependencies for
service_names, in order of dependency. service_names, in order of dependency.
Preserves the original order of self.services where possible, Preserves the original order of self.services where possible,
reordering as needed to resolve links. reordering as needed to resolve dependencies.
Raises NoSuchService if any of the named services do not exist. Raises NoSuchService if any of the named services do not exist.
""" """
if service_names is None or len(service_names) == 0: if service_names is None or len(service_names) == 0:
return self.get_services( return self.get_services(
service_names=[s.name for s in self.services], service_names=[s.name for s in self.services],
include_links=include_links include_deps=include_deps
) )
else: else:
unsorted = [self.get_service(name) for name in service_names] unsorted = [self.get_service(name) for name in service_names]
services = [s for s in self.services if s in unsorted] services = [s for s in self.services if s in unsorted]
if include_links: if include_deps:
services = reduce(self._inject_links, services, []) services = reduce(self._inject_deps, services, [])
uniques = [] uniques = []
[uniques.append(s) for s in services if s not in uniques] [uniques.append(s) for s in services if s not in uniques]
@ -146,6 +167,28 @@ class Project(object):
del service_dict['volumes_from'] del service_dict['volumes_from']
return volumes_from return volumes_from
def get_net(self, service_dict):
if 'net' in service_dict:
net_name = get_service_name_from_net(service_dict.get('net'))
if net_name:
try:
net = self.get_service(net_name)
except NoSuchService:
try:
net = Container.from_id(self.client, net_name)
except APIError:
raise ConfigurationError('Serivce "%s" is trying to use the network of "%s", which is not the name of a service or container.' % (service_dict['name'], net_name))
else:
net = service_dict['net']
del service_dict['net']
else:
net = 'bridge'
return net
def start(self, service_names=None, **options): def start(self, service_names=None, **options):
for service in self.get_services(service_names): for service in self.get_services(service_names):
service.start(**options) service.start(**options)
@ -171,13 +214,13 @@ class Project(object):
def up(self, def up(self,
service_names=None, service_names=None,
start_links=True, start_deps=True,
recreate=True, recreate=True,
insecure_registry=False, insecure_registry=False,
detach=False, detach=False,
do_build=True): do_build=True):
running_containers = [] running_containers = []
for service in self.get_services(service_names, include_links=start_links): for service in self.get_services(service_names, include_deps=start_deps):
if recreate: if recreate:
for (_, container) in service.recreate_containers( for (_, container) in service.recreate_containers(
insecure_registry=insecure_registry, insecure_registry=insecure_registry,
@ -194,7 +237,7 @@ class Project(object):
return running_containers return running_containers
def pull(self, service_names=None, insecure_registry=False): def pull(self, service_names=None, insecure_registry=False):
for service in self.get_services(service_names, include_links=True): for service in self.get_services(service_names, include_deps=True):
service.pull(insecure_registry=insecure_registry) service.pull(insecure_registry=insecure_registry)
def remove_stopped(self, service_names=None, **options): def remove_stopped(self, service_names=None, **options):
@ -207,19 +250,22 @@ class Project(object):
for service in self.get_services(service_names) for service in self.get_services(service_names)
if service.has_container(container, one_off=one_off)] if service.has_container(container, one_off=one_off)]
def _inject_links(self, acc, service): def _inject_deps(self, acc, service):
linked_names = service.get_linked_names() net_name = service.get_net_name()
dep_names = (service.get_linked_names() +
service.get_volumes_from_names() +
([net_name] if net_name else []))
if len(linked_names) > 0: if len(dep_names) > 0:
linked_services = self.get_services( dep_services = self.get_services(
service_names=linked_names, service_names=list(set(dep_names)),
include_links=True include_deps=True
) )
else: else:
linked_services = [] dep_services = []
linked_services.append(service) dep_services.append(service)
return acc + linked_services return acc + dep_services
class NoSuchService(Exception): class NoSuchService(Exception):

View File

@ -88,7 +88,7 @@ ServiceName = namedtuple('ServiceName', 'project service number')
class Service(object): class Service(object):
def __init__(self, name, client=None, project='default', links=None, external_links=None, volumes_from=None, **options): def __init__(self, name, client=None, project='default', links=None, external_links=None, volumes_from=None, net=None, **options):
if not re.match('^%s+$' % VALID_NAME_CHARS, name): if not re.match('^%s+$' % VALID_NAME_CHARS, name):
raise ConfigError('Invalid service name "%s" - only %s are allowed' % (name, VALID_NAME_CHARS)) raise ConfigError('Invalid service name "%s" - only %s are allowed' % (name, VALID_NAME_CHARS))
if not re.match('^%s+$' % VALID_NAME_CHARS, project): if not re.match('^%s+$' % VALID_NAME_CHARS, project):
@ -116,6 +116,7 @@ class Service(object):
self.links = links or [] self.links = links or []
self.external_links = external_links or [] self.external_links = external_links or []
self.volumes_from = volumes_from or [] self.volumes_from = volumes_from or []
self.net = net or None
self.options = options self.options = options
def containers(self, stopped=False, one_off=False): def containers(self, stopped=False, one_off=False):
@ -320,7 +321,6 @@ class Service(object):
if ':' in volume) if ':' in volume)
privileged = options.get('privileged', False) privileged = options.get('privileged', False)
net = options.get('net', 'bridge')
dns = options.get('dns', None) dns = options.get('dns', None)
dns_search = options.get('dns_search', None) dns_search = options.get('dns_search', None)
cap_add = options.get('cap_add', None) cap_add = options.get('cap_add', None)
@ -334,7 +334,7 @@ class Service(object):
binds=volume_bindings, binds=volume_bindings,
volumes_from=self._get_volumes_from(intermediate_container), volumes_from=self._get_volumes_from(intermediate_container),
privileged=privileged, privileged=privileged,
network_mode=net, network_mode=self._get_net(),
dns=dns, dns=dns,
dns_search=dns_search, dns_search=dns_search,
restart_policy=restart, restart_policy=restart,
@ -364,6 +364,15 @@ class Service(object):
def get_linked_names(self): def get_linked_names(self):
return [s.name for (s, _) in self.links] return [s.name for (s, _) in self.links]
def get_volumes_from_names(self):
return [s.name for s in self.volumes_from if isinstance(s, Service)]
def get_net_name(self):
if isinstance(self.net, Service):
return self.net.name
else:
return
def _next_container_name(self, all_containers, one_off=False): def _next_container_name(self, all_containers, one_off=False):
bits = [self.project, self.name] bits = [self.project, self.name]
if one_off: if one_off:
@ -399,7 +408,6 @@ class Service(object):
for volume_source in self.volumes_from: for volume_source in self.volumes_from:
if isinstance(volume_source, Service): if isinstance(volume_source, Service):
containers = volume_source.containers(stopped=True) containers = volume_source.containers(stopped=True)
if not containers: if not containers:
volumes_from.append(volume_source.create_container().id) volumes_from.append(volume_source.create_container().id)
else: else:
@ -413,6 +421,25 @@ class Service(object):
return volumes_from return volumes_from
def _get_net(self):
if not self.net:
return "bridge"
if isinstance(self.net, Service):
containers = self.net.containers()
if len(containers) > 0:
net = 'container:' + containers[0].id
else:
log.warning("Warning: Service %s is trying to use reuse the network stack "
"of another service that is not running." % (self.net.name))
net = None
elif isinstance(self.net, Container):
net = 'container:' + self.net.id
else:
net = self.net
return net
def _get_container_create_options(self, override_options, one_off=False): def _get_container_create_options(self, override_options, one_off=False):
container_options = dict( container_options = dict(
(k, self.options[k]) (k, self.options[k])

View File

@ -44,6 +44,63 @@ class ProjectTest(DockerClientTestCase):
db = project.get_service('db') db = project.get_service('db')
self.assertEqual(db.volumes_from, [data_container]) self.assertEqual(db.volumes_from, [data_container])
project.kill()
project.remove_stopped()
def test_net_from_service(self):
project = Project.from_config(
name='composetest',
config={
'net': {
'image': 'busybox:latest',
'command': ["/bin/sleep", "300"]
},
'web': {
'image': 'busybox:latest',
'net': 'container:net',
'command': ["/bin/sleep", "300"]
},
},
client=self.client,
)
project.up()
web = project.get_service('web')
net = project.get_service('net')
self.assertEqual(web._get_net(), 'container:'+net.containers()[0].id)
project.kill()
project.remove_stopped()
def test_net_from_container(self):
net_container = Container.create(
self.client,
image='busybox:latest',
name='composetest_net_container',
command='/bin/sleep 300'
)
net_container.start()
project = Project.from_config(
name='composetest',
config={
'web': {
'image': 'busybox:latest',
'net': 'container:composetest_net_container'
},
},
client=self.client,
)
project.up()
web = project.get_service('web')
self.assertEqual(web._get_net(), 'container:'+net_container.id)
project.kill()
project.remove_stopped()
def test_start_stop_kill_remove(self): def test_start_stop_kill_remove(self):
web = self.create_service('web') web = self.create_service('web')
db = self.create_service('db') db = self.create_service('db')
@ -199,20 +256,86 @@ class ProjectTest(DockerClientTestCase):
project.kill() project.kill()
project.remove_stopped() project.remove_stopped()
def test_project_up_with_no_deps(self): def test_project_up_starts_depends(self):
console = self.create_service('console') project = Project.from_config(
db = self.create_service('db', volumes=['/var/db']) name='composetest',
web = self.create_service('web', links=[(db, 'db')]) config={
'console': {
project = Project('composetest', [web, db, console], self.client) 'image': 'busybox:latest',
'command': ["/bin/sleep", "300"],
},
'net' : {
'image': 'busybox:latest',
'command': ["/bin/sleep", "300"]
},
'app': {
'image': 'busybox:latest',
'command': ["/bin/sleep", "300"],
'net': 'container:net'
},
'web': {
'image': 'busybox:latest',
'command': ["/bin/sleep", "300"],
'net': 'container:net',
'links': ['app']
},
},
client=self.client,
)
project.start() project.start()
self.assertEqual(len(project.containers()), 0) self.assertEqual(len(project.containers()), 0)
project.up(['web'], start_links=False) project.up(['web'])
self.assertEqual(len(project.containers()), 1) self.assertEqual(len(project.containers()), 3)
self.assertEqual(len(web.containers()), 1) self.assertEqual(len(project.get_service('web').containers()), 1)
self.assertEqual(len(db.containers()), 0) self.assertEqual(len(project.get_service('app').containers()), 1)
self.assertEqual(len(console.containers()), 0) self.assertEqual(len(project.get_service('net').containers()), 1)
self.assertEqual(len(project.get_service('console').containers()), 0)
project.kill()
project.remove_stopped()
def test_project_up_with_no_deps(self):
project = Project.from_config(
name='composetest',
config={
'console': {
'image': 'busybox:latest',
'command': ["/bin/sleep", "300"],
},
'net' : {
'image': 'busybox:latest',
'command': ["/bin/sleep", "300"]
},
'vol': {
'image': 'busybox:latest',
'command': ["/bin/sleep", "300"],
'volumes': ["/tmp"]
},
'app': {
'image': 'busybox:latest',
'command': ["/bin/sleep", "300"],
'net': 'container:net'
},
'web': {
'image': 'busybox:latest',
'command': ["/bin/sleep", "300"],
'net': 'container:net',
'links': ['app'],
'volumes_from': ['vol']
},
},
client=self.client,
)
project.start()
self.assertEqual(len(project.containers()), 0)
project.up(['web'], start_deps=False)
self.assertEqual(len(project.containers(stopped=True)), 2)
self.assertEqual(len(project.get_service('web').containers()), 1)
self.assertEqual(len(project.get_service('vol').containers(stopped=True)), 1)
self.assertEqual(len(project.get_service('net').containers()), 0)
self.assertEqual(len(project.get_service('console').containers()), 0)
project.kill() project.kill()
project.remove_stopped() project.remove_stopped()

View File

@ -2,6 +2,10 @@ from __future__ import unicode_literals
from .. import unittest from .. import unittest
from compose.service import Service from compose.service import Service
from compose.project import Project, ConfigurationError from compose.project import Project, ConfigurationError
from compose.container import Container
import mock
import docker
class ProjectTest(unittest.TestCase): class ProjectTest(unittest.TestCase):
def test_from_dict(self): def test_from_dict(self):
@ -120,7 +124,7 @@ class ProjectTest(unittest.TestCase):
) )
project = Project('test', [web, db, cache, console], None) project = Project('test', [web, db, cache, console], None)
self.assertEqual( self.assertEqual(
project.get_services(['console'], include_links=True), project.get_services(['console'], include_deps=True),
[db, web, console] [db, web, console]
) )
@ -136,6 +140,105 @@ class ProjectTest(unittest.TestCase):
) )
project = Project('test', [web, db], None) project = Project('test', [web, db], None)
self.assertEqual( self.assertEqual(
project.get_services(['web', 'db'], include_links=True), project.get_services(['web', 'db'], include_deps=True),
[db, web] [db, web]
) )
def test_use_volumes_from_container(self):
container_id = 'aabbccddee'
container_dict = dict(Name='aaa', Id=container_id)
mock_client = mock.create_autospec(docker.Client)
mock_client.inspect_container.return_value = container_dict
project = Project.from_dicts('test', [
{
'name': 'test',
'image': 'busybox:latest',
'volumes_from': ['aaa']
}
], mock_client)
self.assertEqual(project.get_service('test')._get_volumes_from(), [container_id])
def test_use_volumes_from_service_no_container(self):
container_name = 'test_vol_1'
mock_client = mock.create_autospec(docker.Client)
mock_client.containers.return_value = [
{
"Name": container_name,
"Names": [container_name],
"Id": container_name,
"Image": 'busybox:latest'
}
]
project = Project.from_dicts('test', [
{
'name': 'vol',
'image': 'busybox:latest'
},
{
'name': 'test',
'image': 'busybox:latest',
'volumes_from': ['vol']
}
], mock_client)
self.assertEqual(project.get_service('test')._get_volumes_from(), [container_name])
@mock.patch.object(Service, 'containers')
def test_use_volumes_from_service_container(self, mock_return):
container_ids = ['aabbccddee', '12345']
mock_return.return_value = [
mock.Mock(id=container_id, spec=Container)
for container_id in container_ids]
project = Project.from_dicts('test', [
{
'name': 'vol',
'image': 'busybox:latest'
},
{
'name': 'test',
'image': 'busybox:latest',
'volumes_from': ['vol']
}
], None)
self.assertEqual(project.get_service('test')._get_volumes_from(), container_ids)
def test_use_net_from_container(self):
container_id = 'aabbccddee'
container_dict = dict(Name='aaa', Id=container_id)
mock_client = mock.create_autospec(docker.Client)
mock_client.inspect_container.return_value = container_dict
project = Project.from_dicts('test', [
{
'name': 'test',
'image': 'busybox:latest',
'net': 'container:aaa'
}
], mock_client)
service = project.get_service('test')
self.assertEqual(service._get_net(), 'container:'+container_id)
def test_use_net_from_service(self):
container_name = 'test_aaa_1'
mock_client = mock.create_autospec(docker.Client)
mock_client.containers.return_value = [
{
"Name": container_name,
"Names": [container_name],
"Id": container_name,
"Image": 'busybox:latest'
}
]
project = Project.from_dicts('test', [
{
'name': 'aaa',
'image': 'busybox:latest'
},
{
'name': 'test',
'image': 'busybox:latest',
'net': 'container:aaa'
}
], mock_client)
service = project.get_service('test')
self.assertEqual(service._get_net(), 'container:'+container_name)

View File

@ -65,6 +65,95 @@ class SortServiceTest(unittest.TestCase):
self.assertEqual(sorted_services[1]['name'], 'parent') self.assertEqual(sorted_services[1]['name'], 'parent')
self.assertEqual(sorted_services[2]['name'], 'grandparent') self.assertEqual(sorted_services[2]['name'], 'grandparent')
def test_sort_service_dicts_4(self):
services = [
{
'name': 'child'
},
{
'name': 'parent',
'volumes_from': ['child']
},
{
'links': ['parent'],
'name': 'grandparent'
},
]
sorted_services = sort_service_dicts(services)
self.assertEqual(len(sorted_services), 3)
self.assertEqual(sorted_services[0]['name'], 'child')
self.assertEqual(sorted_services[1]['name'], 'parent')
self.assertEqual(sorted_services[2]['name'], 'grandparent')
def test_sort_service_dicts_5(self):
services = [
{
'links': ['parent'],
'name': 'grandparent'
},
{
'name': 'parent',
'net': 'container:child'
},
{
'name': 'child'
}
]
sorted_services = sort_service_dicts(services)
self.assertEqual(len(sorted_services), 3)
self.assertEqual(sorted_services[0]['name'], 'child')
self.assertEqual(sorted_services[1]['name'], 'parent')
self.assertEqual(sorted_services[2]['name'], 'grandparent')
def test_sort_service_dicts_6(self):
services = [
{
'links': ['parent'],
'name': 'grandparent'
},
{
'name': 'parent',
'volumes_from': ['child']
},
{
'name': 'child'
}
]
sorted_services = sort_service_dicts(services)
self.assertEqual(len(sorted_services), 3)
self.assertEqual(sorted_services[0]['name'], 'child')
self.assertEqual(sorted_services[1]['name'], 'parent')
self.assertEqual(sorted_services[2]['name'], 'grandparent')
def test_sort_service_dicts_7(self):
services = [
{
'net': 'container:three',
'name': 'four'
},
{
'links': ['two'],
'name': 'three'
},
{
'name': 'two',
'volumes_from': ['one']
},
{
'name': 'one'
}
]
sorted_services = sort_service_dicts(services)
self.assertEqual(len(sorted_services), 4)
self.assertEqual(sorted_services[0]['name'], 'one')
self.assertEqual(sorted_services[1]['name'], 'two')
self.assertEqual(sorted_services[2]['name'], 'three')
self.assertEqual(sorted_services[3]['name'], 'four')
def test_sort_service_dicts_circular_imports(self): def test_sort_service_dicts_circular_imports(self):
services = [ services = [
{ {