diff --git a/fig/service.py b/fig/service.py index bbb1bb669..48bf3e085 100644 --- a/fig/service.py +++ b/fig/service.py @@ -5,6 +5,7 @@ from .packages.docker.errors import APIError import logging import re import os +from operator import attrgetter import sys from .container import Container from .progress_stream import stream_output, StreamOutputError @@ -285,12 +286,17 @@ class Service(object): def _get_volumes_from(self, intermediate_container=None): volumes_from = [] - for v in self.volumes_from: - if isinstance(v, Service): - for container in v.containers(stopped=True): - volumes_from.append(container.id) - elif isinstance(v, Container): - volumes_from.append(v.id) + for volume_source in self.volumes_from: + if isinstance(volume_source, Service): + containers = volume_source.containers(stopped=True) + + if not containers: + volumes_from.append(volume_source.create_container().id) + else: + volumes_from.extend(map(attrgetter('id'), containers)) + + elif isinstance(volume_source, Container): + volumes_from.append(volume_source.id) if intermediate_container: volumes_from.append(intermediate_container.id) diff --git a/tests/unit/service_test.py b/tests/unit/service_test.py index 84f589b2d..028eefecd 100644 --- a/tests/unit/service_test.py +++ b/tests/unit/service_test.py @@ -6,6 +6,7 @@ from .. import unittest import mock from fig import Service +from fig.container import Container from fig.service import ( ConfigError, split_port, @@ -38,6 +39,44 @@ class ServiceTest(unittest.TestCase): self.assertRaises(ConfigError, lambda: Service(name='foo', port=['8000'])) Service(name='foo', ports=['8000']) + def test_get_volumes_from_container(self): + container_id = 'aabbccddee' + service = Service( + 'test', + volumes_from=[mock.Mock(id=container_id, spec=Container)]) + + self.assertEqual(service._get_volumes_from(), [container_id]) + + def test_get_volumes_from_intermediate_container(self): + container_id = 'aabbccddee' + service = Service('test') + container = mock.Mock(id=container_id, spec=Container) + + self.assertEqual(service._get_volumes_from(container), [container_id]) + + def test_get_volumes_from_service_container_exists(self): + container_ids = ['aabbccddee', '12345'] + from_service = mock.create_autospec(Service) + from_service.containers.return_value = [ + mock.Mock(id=container_id, spec=Container) + for container_id in container_ids + ] + service = Service('test', volumes_from=[from_service]) + + self.assertEqual(service._get_volumes_from(), container_ids) + + def test_get_volumes_from_service_no_container(self): + container_id = 'abababab' + from_service = mock.create_autospec(Service) + from_service.containers.return_value = [] + from_service.create_container.return_value = mock.Mock( + id=container_id, + spec=Container) + service = Service('test', volumes_from=[from_service]) + + self.assertEqual(service._get_volumes_from(), [container_id]) + from_service.create_container.assert_called_once_with() + def test_split_port_with_host_ip(self): internal_port, external_port = split_port("127.0.0.1:1000:2000") self.assertEqual(internal_port, "2000")