diff --git a/fig/service.py b/fig/service.py index 0874ec599..7c5bffbe4 100644 --- a/fig/service.py +++ b/fig/service.py @@ -4,6 +4,7 @@ from collections import namedtuple import logging import re import os +from operator import attrgetter import sys from docker.errors import APIError @@ -308,12 +309,17 @@ class Service(object): def _get_volumes_from(self, intermediate_container=None): volumes_from = [] - for v in self.volumes_from: - if isinstance(v, Service): - for container in v.containers(stopped=True): - volumes_from.append(container.id) - elif isinstance(v, Container): - volumes_from.append(v.id) + for volume_source in self.volumes_from: + if isinstance(volume_source, Service): + containers = volume_source.containers(stopped=True) + + if not containers: + volumes_from.append(volume_source.create_container().id) + else: + volumes_from.extend(map(attrgetter('id'), containers)) + + elif isinstance(volume_source, Container): + volumes_from.append(volume_source.id) if intermediate_container: volumes_from.append(intermediate_container.id) diff --git a/tests/unit/service_test.py b/tests/unit/service_test.py index 650afa5a6..2e48248b1 100644 --- a/tests/unit/service_test.py +++ b/tests/unit/service_test.py @@ -8,6 +8,7 @@ import mock import docker from fig import Service +from fig.container import Container from fig.service import ( ConfigError, split_port, @@ -44,6 +45,44 @@ class ServiceTest(unittest.TestCase): self.assertRaises(ConfigError, lambda: Service(name='foo', port=['8000'])) Service(name='foo', ports=['8000']) + def test_get_volumes_from_container(self): + container_id = 'aabbccddee' + service = Service( + 'test', + volumes_from=[mock.Mock(id=container_id, spec=Container)]) + + self.assertEqual(service._get_volumes_from(), [container_id]) + + def test_get_volumes_from_intermediate_container(self): + container_id = 'aabbccddee' + service = Service('test') + container = mock.Mock(id=container_id, spec=Container) + + self.assertEqual(service._get_volumes_from(container), [container_id]) + + def test_get_volumes_from_service_container_exists(self): + container_ids = ['aabbccddee', '12345'] + from_service = mock.create_autospec(Service) + from_service.containers.return_value = [ + mock.Mock(id=container_id, spec=Container) + for container_id in container_ids + ] + service = Service('test', volumes_from=[from_service]) + + self.assertEqual(service._get_volumes_from(), container_ids) + + def test_get_volumes_from_service_no_container(self): + container_id = 'abababab' + from_service = mock.create_autospec(Service) + from_service.containers.return_value = [] + from_service.create_container.return_value = mock.Mock( + id=container_id, + spec=Container) + service = Service('test', volumes_from=[from_service]) + + self.assertEqual(service._get_volumes_from(), [container_id]) + from_service.create_container.assert_called_once_with() + def test_split_port_with_host_ip(self): internal_port, external_port = split_port("127.0.0.1:1000:2000") self.assertEqual(internal_port, "2000")