diff --git a/fig/cli/main.py b/fig/cli/main.py index 9585371df..4700803d1 100644 --- a/fig/cli/main.py +++ b/fig/cli/main.py @@ -301,10 +301,9 @@ class TopLevelCommand(Command): """ detached = options['-d'] - new = self.project.up(service_names=options['SERVICE']) + to_attach = self.project.up(service_names=options['SERVICE']) if not detached: - to_attach = [c for (s, c) in new] print("Attaching to", list_containers(to_attach)) log_printer = LogPrinter(to_attach, attach_params={"logs": True}) diff --git a/fig/packages/docker/client.py b/fig/packages/docker/client.py index 77bb962f8..8b447d785 100644 --- a/fig/packages/docker/client.py +++ b/fig/packages/docker/client.py @@ -698,8 +698,8 @@ class Client(requests.Session): params={'term': term}), True) - def start(self, container, binds=None, port_bindings=None, lxc_conf=None, - publish_all_ports=False, links=None, privileged=False): + def start(self, container, binds=None, volumes_from=None, port_bindings=None, + lxc_conf=None, publish_all_ports=False, links=None, privileged=False): if isinstance(container, dict): container = container.get('Id') @@ -718,6 +718,11 @@ class Client(requests.Session): ] start_config['Binds'] = bind_pairs + if volumes_from and not isinstance(volumes_from, six.string_types): + volumes_from = ','.join(volumes_from) + + start_config['VolumesFrom'] = volumes_from + if port_bindings: start_config['PortBindings'] = utils.convert_port_bindings( port_bindings diff --git a/fig/project.py b/fig/project.py index 38bbba222..b271a810a 100644 --- a/fig/project.py +++ b/fig/project.py @@ -105,23 +105,6 @@ class Project(object): unsorted = [self.get_service(name) for name in service_names] return [s for s in self.services if s in unsorted] - def recreate_containers(self, service_names=None): - """ - For each service, create or recreate their containers. - Returns a tuple with two lists. The first is a list of - (service, old_container) tuples; the second is a list - of (service, new_container) tuples. - """ - old = [] - new = [] - - for service in self.get_services(service_names): - (s_old, s_new) = service.recreate_containers() - old += [(service, container) for container in s_old] - new += [(service, container) for container in s_new] - - return (old, new) - def start(self, service_names=None, **options): for service in self.get_services(service_names): service.start(**options) @@ -142,15 +125,13 @@ class Project(object): log.info('%s uses an image, skipping' % service.name) def up(self, service_names=None): - (old, new) = self.recreate_containers(service_names=service_names) + new_containers = [] - for (service, container) in new: - service.start_container(container) + for service in self.get_services(service_names): + for (_, new) in service.recreate_containers(): + new_containers.append(new) - for (service, container) in old: - container.remove() - - return new + return new_containers def remove_stopped(self, service_names=None, **options): for service in self.get_services(service_names): diff --git a/fig/service.py b/fig/service.py index 3e032fa8f..20c4e120f 100644 --- a/fig/service.py +++ b/fig/service.py @@ -154,25 +154,24 @@ class Service(object): def recreate_containers(self, **override_options): """ - If a container for this service doesn't exist, create one. If there are - any, stop them and create new ones. Does not remove the old containers. + If a container for this service doesn't exist, create and start one. If there are + any, stop them, create+start new ones, and remove the old containers. """ containers = self.containers(stopped=True) if len(containers) == 0: log.info("Creating %s..." % self.next_container_name()) - return ([], [self.create_container(**override_options)]) + container = self.create_container(**override_options) + self.start_container(container) + return [(None, container)] else: - old_containers = [] - new_containers = [] + tuples = [] for c in containers: log.info("Recreating %s..." % c.name) - (old_container, new_container) = self.recreate_container(c, **override_options) - old_containers.append(old_container) - new_containers.append(new_container) + tuples.append(self.recreate_container(c, **override_options)) - return (old_containers, new_containers) + return tuples def recreate_container(self, container, **override_options): if container.is_running: @@ -185,17 +184,20 @@ class Service(object): entrypoint=['echo'], command=[], ) - intermediate_container.start() + intermediate_container.start(volumes_from=container.id) intermediate_container.wait() container.remove() options = dict(override_options) options['volumes_from'] = intermediate_container.id new_container = self.create_container(**options) + self.start_container(new_container, volumes_from=intermediate_container.id) + + intermediate_container.remove() return (intermediate_container, new_container) - def start_container(self, container=None, **override_options): + def start_container(self, container=None, volumes_from=None, **override_options): if container is None: container = self.create_container(**override_options) @@ -228,6 +230,7 @@ class Service(object): links=self._get_links(link_to_self=override_options.get('one_off', False)), port_bindings=port_bindings, binds=volume_bindings, + volumes_from=volumes_from, privileged=privileged, ) return container diff --git a/tests/project_test.py b/tests/project_test.py index b8a5d6823..bde40e89b 100644 --- a/tests/project_test.py +++ b/tests/project_test.py @@ -63,29 +63,6 @@ class ProjectTest(DockerClientTestCase): project = Project('test', [web], self.client) self.assertEqual(project.get_service('web'), web) - def test_recreate_containers(self): - web = self.create_service('web') - db = self.create_service('db') - project = Project('test', [web, db], self.client) - - old_web_container = web.create_container() - self.assertEqual(len(web.containers(stopped=True)), 1) - self.assertEqual(len(db.containers(stopped=True)), 0) - - (old, new) = project.recreate_containers() - self.assertEqual(len(old), 1) - self.assertEqual(old[0][0], web) - self.assertEqual(len(new), 2) - self.assertEqual(new[0][0], web) - self.assertEqual(new[1][0], db) - - self.assertEqual(len(web.containers(stopped=True)), 1) - self.assertEqual(len(db.containers(stopped=True)), 1) - - # remove intermediate containers - for (service, container) in old: - container.remove() - def test_start_stop_kill_remove(self): web = self.create_service('web') db = self.create_service('db') @@ -121,12 +98,23 @@ class ProjectTest(DockerClientTestCase): def test_project_up(self): web = self.create_service('web') - db = self.create_service('db') + db = self.create_service('db', volumes=['/var/db']) project = Project('figtest', [web, db], self.client) project.start() self.assertEqual(len(project.containers()), 0) + + project.up(['db']) + self.assertEqual(len(project.containers()), 1) + old_db_id = project.containers()[0].id + db_volume_path = project.containers()[0].inspect()['Volumes']['/var/db'] + project.up() self.assertEqual(len(project.containers()), 2) + + db_container = [c for c in project.containers() if 'db' in c.name][0] + self.assertNotEqual(c.id, old_db_id) + self.assertEqual(c.inspect()['Volumes']['/var/db'], db_volume_path) + project.kill() project.remove_stopped() diff --git a/tests/service_test.py b/tests/service_test.py index 5e8fe3ba3..a8ea017a8 100644 --- a/tests/service_test.py +++ b/tests/service_test.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals from __future__ import absolute_import from fig import Service from fig.service import CannotBeScaledError, ConfigError +from fig.packages.docker.client import APIError from .testcases import DockerClientTestCase @@ -132,23 +133,22 @@ class ServiceTest(DockerClientTestCase): num_containers_before = len(self.client.containers(all=True)) service.options['environment']['FOO'] = '2' - (intermediate, new) = service.recreate_containers() - self.assertEqual(len(intermediate), 1) - self.assertEqual(len(new), 1) + tuples = service.recreate_containers() + self.assertEqual(len(tuples), 1) - new_container = new[0] - intermediate_container = intermediate[0] + intermediate_container = tuples[0][0] + new_container = tuples[0][1] self.assertEqual(intermediate_container.dictionary['Config']['Entrypoint'], ['echo']) self.assertEqual(new_container.dictionary['Config']['Entrypoint'], ['ps']) self.assertEqual(new_container.dictionary['Config']['Cmd'], ['ax']) self.assertIn('FOO=2', new_container.dictionary['Config']['Env']) self.assertEqual(new_container.name, 'figtest_db_1') - service.start_container(new_container) self.assertEqual(new_container.inspect()['Volumes']['/var/db'], volume_path) - self.assertEqual(len(self.client.containers(all=True)), num_containers_before + 1) + self.assertEqual(len(self.client.containers(all=True)), num_containers_before) self.assertNotEqual(old_container.id, new_container.id) + self.assertRaises(APIError, lambda: self.client.inspect_container(intermediate_container.id)) def test_start_container_passes_through_options(self): db = self.create_service('db')