diff --git a/README.md b/README.md index 585cdd89d..327ac2327 100644 --- a/README.md +++ b/README.md @@ -233,6 +233,15 @@ For example: Note that this will not start any services that the command's service links to. So if, for example, your one-off command talks to your database, you will need to run `fig up -d db` first. +#### scale + +Set number of containers to run for a service. + +Numbers are specified in the form `service=num` as arguments. +For example: + + $ fig scale web=2 worker=3 + #### start Start existing containers for a service. diff --git a/fig/cli/main.py b/fig/cli/main.py index 24a0180d9..44422777f 100644 --- a/fig/cli/main.py +++ b/fig/cli/main.py @@ -10,6 +10,7 @@ from inspect import getdoc from .. import __version__ from ..project import NoSuchService +from ..service import CannotBeScaledError from .command import Command from .formatter import Formatter from .log_printer import LogPrinter @@ -82,6 +83,7 @@ class TopLevelCommand(Command): ps List containers rm Remove stopped containers run Run a one-off command + scale Set number of containers for a service start Start services stop Stop services up Create and start containers @@ -220,6 +222,31 @@ class TopLevelCommand(Command): service.start_container(container, ports=None) c.run() + def scale(self, options): + """ + Set number of containers to run for a service. + + Numbers are specified in the form `service=num` as arguments. + For example: + + $ fig scale web=2 worker=3 + + Usage: scale [SERVICE=NUM...] + """ + for s in options['SERVICE=NUM']: + if '=' not in s: + raise UserError('Arguments to scale should be in the form service=num') + service_name, num = s.split('=', 1) + try: + num = int(num) + except ValueError: + raise UserError('Number of containers for service "%s" is not a number' % service) + try: + self.project.get_service(service_name).scale(num) + except CannotBeScaledError: + raise UserError('Service "%s" cannot be scaled because it specifies a port on the host. If multiple containers for this service were created, the port would clash.\n\nRemove the ":" from the port definition in fig.yml so Docker can choose a random port for each container.' % service_name) + + def start(self, options): """ Start existing containers. diff --git a/fig/service.py b/fig/service.py index 29e867e9c..90194fdd5 100644 --- a/fig/service.py +++ b/fig/service.py @@ -14,6 +14,10 @@ class BuildError(Exception): pass +class CannotBeScaledError(Exception): + pass + + class Service(object): def __init__(self, name, client=None, project='default', links=[], **options): if not re.match('^[a-zA-Z0-9]+$', name): @@ -56,6 +60,40 @@ class Service(object): log.info("Killing %s..." % c.name) c.kill(**options) + def scale(self, desired_num): + if not self.can_be_scaled(): + raise CannotBeScaledError() + + # Create enough containers + containers = self.containers(stopped=True) + while len(containers) < desired_num: + containers.append(self.create_container()) + + running_containers = [] + stopped_containers = [] + for c in containers: + if c.is_running: + running_containers.append(c) + else: + stopped_containers.append(c) + running_containers.sort(key=lambda c: c.number) + stopped_containers.sort(key=lambda c: c.number) + + # Stop containers + while len(running_containers) > desired_num: + c = running_containers.pop() + log.info("Stopping %s..." % c.name) + c.stop(timeout=1) + stopped_containers.append(c) + + # Start containers + while len(running_containers) < desired_num: + c = stopped_containers.pop(0) + log.info("Starting %s..." % c.name) + c.start() + running_containers.append(c) + + def remove_stopped(self, **options): for c in self.containers(stopped=True): if not c.is_running: @@ -231,6 +269,12 @@ class Service(object): """ return '%s_%s' % (self.project, self.name) + def can_be_scaled(self): + for port in self.options.get('ports', []): + if ':' in str(port): + return False + return True + NAME_RE = re.compile(r'^([^_]+)_([^_]+)_(run_)?(\d+)$') diff --git a/tests/cli_test.py b/tests/cli_test.py index 2146a9062..0d9a2f540 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -13,3 +13,26 @@ class CLITestCase(unittest.TestCase): def test_ps(self): self.command.dispatch(['ps'], None) + + def test_scale(self): + project = self.command.project + + self.command.scale({'SERVICE=NUM': ['simple=1']}) + self.assertEqual(len(project.get_service('simple').containers()), 1) + + self.command.scale({'SERVICE=NUM': ['simple=3', 'another=2']}) + self.assertEqual(len(project.get_service('simple').containers()), 3) + self.assertEqual(len(project.get_service('another').containers()), 2) + + self.command.scale({'SERVICE=NUM': ['simple=1', 'another=1']}) + self.assertEqual(len(project.get_service('simple').containers()), 1) + self.assertEqual(len(project.get_service('another').containers()), 1) + + self.command.scale({'SERVICE=NUM': ['simple=1', 'another=1']}) + self.assertEqual(len(project.get_service('simple').containers()), 1) + self.assertEqual(len(project.get_service('another').containers()), 1) + + self.command.scale({'SERVICE=NUM': ['simple=0', 'another=0']}) + self.assertEqual(len(project.get_service('simple').containers()), 0) + self.assertEqual(len(project.get_service('another').containers()), 0) + diff --git a/tests/fixtures/simple-figfile/fig.yml b/tests/fixtures/simple-figfile/fig.yml index aef2d39ba..225323755 100644 --- a/tests/fixtures/simple-figfile/fig.yml +++ b/tests/fixtures/simple-figfile/fig.yml @@ -1,2 +1,6 @@ simple: image: ubuntu + command: /bin/sleep 300 +another: + image: ubuntu + command: /bin/sleep 300 diff --git a/tests/service_test.py b/tests/service_test.py index 2ccf17555..ff7b24160 100644 --- a/tests/service_test.py +++ b/tests/service_test.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals from __future__ import absolute_import from fig import Service +from fig.service import CannotBeScaledError from .testcases import DockerClientTestCase @@ -193,3 +194,20 @@ class ServiceTest(DockerClientTestCase): self.assertIn('8000/tcp', container['HostConfig']['PortBindings']) self.assertEqual(container['HostConfig']['PortBindings']['8000/tcp'][0]['HostPort'], '8001') + def test_scale(self): + service = self.create_service('web') + service.scale(1) + self.assertEqual(len(service.containers()), 1) + service.scale(3) + self.assertEqual(len(service.containers()), 3) + service.scale(1) + self.assertEqual(len(service.containers()), 1) + service.scale(0) + self.assertEqual(len(service.containers()), 0) + + def test_scale_on_service_that_cannot_be_scaled(self): + service = self.create_service('web', ports=['8000:8000']) + self.assertRaises(CannotBeScaledError, lambda: service.scale(1)) + + +