diff --git a/fig/project.py b/fig/project.py index 4e74c3afe..5c798d384 100644 --- a/fig/project.py +++ b/fig/project.py @@ -139,13 +139,17 @@ class Project(object): log.info('%s uses an image, skipping' % service.name) def up(self, service_names=None, start_links=True, keep_old=False): - new_containers = [] + running_containers = [] for service in self.get_services(service_names, include_links=start_links): - for (_, new) in service.recreate_containers(keep_old): - new_containers.append(new) + if keep_old: + for container in service.start_or_create_containers(): + running_containers.append(container) + else: + for (_, container) in service.recreate_containers(): + running_containers.append(container) - return new_containers + return running_containers def remove_stopped(self, service_names=None, **options): for service in self.get_services(service_names): diff --git a/fig/service.py b/fig/service.py index cd9cbf832..7de0ae7dc 100644 --- a/fig/service.py +++ b/fig/service.py @@ -154,7 +154,7 @@ class Service(object): return Container.create(self.client, **container_options) raise - def recreate_containers(self, keep_old=False, **override_options): + def recreate_containers(self, **override_options): """ If a container for this service doesn't exist, create and start one. If there are any, stop them, create+start new ones, and remove the old containers. @@ -166,8 +166,6 @@ class Service(object): container = self.create_container(**override_options) self.start_container(container) return [(None, container)] - elif keep_old: - return [(None, self.start_container_if_stopped(c)) for c in containers] else: tuples = [] @@ -249,6 +247,16 @@ class Service(object): ) return container + def start_or_create_containers(self): + containers = self.containers(stopped=True) + + if len(containers) == 0: + log.info("Creating %s..." % self.next_container_name()) + new_container = self.create_container() + return [self.start_container(new_container)] + else: + return [self.start_container_if_stopped(c) for c in containers] + def get_linked_names(self): return [s.name for (s, _) in self.links] diff --git a/tests/integration/project_test.py b/tests/integration/project_test.py index f2b9075ef..0c5c1aa76 100644 --- a/tests/integration/project_test.py +++ b/tests/integration/project_test.py @@ -74,7 +74,7 @@ class ProjectTest(DockerClientTestCase): project.kill() project.remove_stopped() - def test_project_up_with_keep_old(self): + def test_project_up_with_keep_old_running(self): web = self.create_service('web') db = self.create_service('db', volumes=['/var/db']) project = Project('figtest', [web, db], self.client) @@ -96,6 +96,34 @@ class ProjectTest(DockerClientTestCase): project.kill() project.remove_stopped() + def test_project_up_with_keep_old_stopped(self): + web = self.create_service('web') + db = self.create_service('db', volumes=['/var/db']) + project = Project('figtest', [web, db], self.client) + project.start() + self.assertEqual(len(project.containers()), 0) + + project.up(['db']) + project.stop() + + old_containers = project.containers(stopped=True) + + self.assertEqual(len(old_containers), 1) + old_db_id = old_containers[0].id + db_volume_path = old_containers[0].inspect()['Volumes']['/var/db'] + + project.up(keep_old=True) + + new_containers = project.containers(stopped=True) + self.assertEqual(len(new_containers), 2) + + db_container = [c for c in new_containers if 'db' in c.name][0] + self.assertEqual(c.id, old_db_id) + self.assertEqual(c.inspect()['Volumes']['/var/db'], db_volume_path) + + project.kill() + project.remove_stopped() + def test_project_up_without_auto_start(self): console = self.create_service('console', auto_start=False) db = self.create_service('db') diff --git a/tests/integration/service_test.py b/tests/integration/service_test.py index 8f4d3f791..78ddbd850 100644 --- a/tests/integration/service_test.py +++ b/tests/integration/service_test.py @@ -132,52 +132,6 @@ class ServiceTest(DockerClientTestCase): self.assertNotEqual(old_container.id, new_container.id) self.assertRaises(APIError, lambda: self.client.inspect_container(intermediate_container.id)) - def test_recreate_containers_with_keep_old_running(self): - service = self.create_service( - 'db', - environment={'FOO': '1'}, - volumes=['/var/db'], - entrypoint=['ps'], - command=['ax'] - ) - old_container = service.create_container() - service.start_container(old_container) - - num_containers_before = len(self.client.containers(all=True)) - - tuples = service.recreate_containers(keep_old=True) - self.assertEqual(len(tuples), 1) - - intermediate_container = tuples[0][0] - new_container = tuples[0][1] - - self.assertIsNone(intermediate_container) - self.assertEqual(len(self.client.containers(all=True)), num_containers_before) - self.assertEqual(old_container.id, new_container.id) - - def test_recreate_containers_with_keep_old_stopped(self): - service = self.create_service( - 'db', - environment={'FOO': '1'}, - volumes=['/var/db'], - entrypoint=['ps'], - command=['ax'] - ) - old_container = service.create_container() - old_container.stop() - - num_containers_before = len(self.client.containers(all=True)) - - tuples = service.recreate_containers(keep_old=True) - self.assertEqual(len(tuples), 1) - - intermediate_container = tuples[0][0] - new_container = tuples[0][1] - - self.assertIsNone(intermediate_container) - self.assertEqual(len(self.client.containers(all=True)), num_containers_before) - self.assertEqual(old_container.id, new_container.id) - def test_start_container_passes_through_options(self): db = self.create_service('db') db.start_container(environment={'FOO': 'BAR'})