diff --git a/fig/cli/main.py b/fig/cli/main.py index 66d2bdf28..909e3e654 100644 --- a/fig/cli/main.py +++ b/fig/cli/main.py @@ -8,7 +8,7 @@ import signal from inspect import getdoc from .. import __version__ -from ..project import NoSuchService +from ..project import NoSuchService, DependencyError from ..service import CannotBeScaledError from .command import Command from .formatter import Formatter @@ -40,10 +40,7 @@ def main(): except KeyboardInterrupt: log.error("\nAborting.") exit(1) - except UserError as e: - log.error(e.msg) - exit(1) - except NoSuchService as e: + except (UserError, NoSuchService, DependencyError) as e: log.error(e.msg) exit(1) except NoSuchCommand as e: diff --git a/fig/compat/__init__.py b/fig/compat/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/fig/compat/functools.py b/fig/compat/functools.py deleted file mode 100644 index 38a48c334..000000000 --- a/fig/compat/functools.py +++ /dev/null @@ -1,23 +0,0 @@ - -# Taken from python2.7/3.3 functools -def cmp_to_key(mycmp): - """Convert a cmp= function into a key= function""" - class K(object): - __slots__ = ['obj'] - def __init__(self, obj): - self.obj = obj - def __lt__(self, other): - return mycmp(self.obj, other.obj) < 0 - def __gt__(self, other): - return mycmp(self.obj, other.obj) > 0 - def __eq__(self, other): - return mycmp(self.obj, other.obj) == 0 - def __le__(self, other): - return mycmp(self.obj, other.obj) <= 0 - def __ge__(self, other): - return mycmp(self.obj, other.obj) >= 0 - def __ne__(self, other): - return mycmp(self.obj, other.obj) != 0 - __hash__ = None - return K - diff --git a/fig/project.py b/fig/project.py index f77da5f7d..157044cbf 100644 --- a/fig/project.py +++ b/fig/project.py @@ -2,21 +2,36 @@ from __future__ import unicode_literals from __future__ import absolute_import import logging from .service import Service -from .compat.functools import cmp_to_key log = logging.getLogger(__name__) + def sort_service_dicts(services): - # Sort in dependency order - def cmp(x, y): - x_deps_y = y['name'] in x.get('links', []) - y_deps_x = x['name'] in y.get('links', []) - if x_deps_y and not y_deps_x: - return 1 - elif y_deps_x and not x_deps_y: - return -1 - return 0 - return sorted(services, key=cmp_to_key(cmp)) + # Get all services that are dependant on another. + dependent_services = [s for s in services if s.get('links')] + flatten_links = sum([s['links'] for s in dependent_services], []) + # Get all services that are not linked to and don't link to others. + non_dependent_sevices = [s for s in services if s['name'] not in flatten_links and not s.get('links')] + sorted_services = [] + # Topological sort. + while dependent_services: + n = dependent_services.pop() + # Check if a service is dependent on itself, if so raise an error. + if n['name'] in n.get('links', []): + raise DependencyError('A service can not link to itself: %s' % n['name']) + sorted_services.append(n) + for l in n['links']: + # Get the linked service. + linked_service = next(s for s in services if l == s['name']) + # Check that there isn't a circular import between services. + if n['name'] in linked_service.get('links', []): + raise DependencyError('Circular import between %s and %s' % (n['name'], linked_service['name'])) + # Check the linked service has no links and is not already in the + # sorted service list. + if not linked_service.get('links') and linked_service not in sorted_services: + sorted_services.insert(0, linked_service) + return non_dependent_sevices + sorted_services + class Project(object): """ @@ -134,3 +149,11 @@ class NoSuchService(Exception): def __str__(self): return self.msg + + +class DependencyError(Exception): + def __init__(self, msg): + self.msg = msg + + def __str__(self): + return self.msg \ No newline at end of file diff --git a/tests/sort_service_test.py b/tests/sort_service_test.py new file mode 100644 index 000000000..d5218cd36 --- /dev/null +++ b/tests/sort_service_test.py @@ -0,0 +1,103 @@ +from fig.project import sort_service_dicts, DependencyError +from . import unittest + + +class SortServiceTest(unittest.TestCase): + def test_sort_service_dicts_1(self): + services = [ + { + 'links': ['redis'], + 'name': 'web' + }, + { + 'name': 'grunt' + }, + { + 'name': 'redis' + } + ] + + sorted_services = sort_service_dicts(services) + self.assertEqual(len(sorted_services), 3) + self.assertEqual(sorted_services[0]['name'], 'grunt') + self.assertEqual(sorted_services[1]['name'], 'redis') + self.assertEqual(sorted_services[2]['name'], 'web') + + def test_sort_service_dicts_2(self): + services = [ + { + 'links': ['redis', 'postgres'], + 'name': 'web' + }, + { + 'name': 'postgres', + 'links': ['redis'] + }, + { + 'name': 'redis' + } + ] + + sorted_services = sort_service_dicts(services) + self.assertEqual(len(sorted_services), 3) + self.assertEqual(sorted_services[0]['name'], 'redis') + self.assertEqual(sorted_services[1]['name'], 'postgres') + self.assertEqual(sorted_services[2]['name'], 'web') + + def test_sort_service_dicts_circular_imports(self): + services = [ + { + 'links': ['redis'], + 'name': 'web' + }, + { + 'name': 'redis', + 'links': ['web'] + }, + ] + + try: + sort_service_dicts(services) + except DependencyError as e: + self.assertIn('redis', e.msg) + self.assertIn('web', e.msg) + else: + self.fail('Should have thrown an DependencyError') + + def test_sort_service_dicts_circular_imports_2(self): + services = [ + { + 'links': ['postgres', 'redis'], + 'name': 'web' + }, + { + 'name': 'redis', + 'links': ['web'] + }, + { + 'name': 'postgres' + } + ] + + try: + sort_service_dicts(services) + except DependencyError as e: + self.assertIn('redis', e.msg) + self.assertIn('web', e.msg) + else: + self.fail('Should have thrown an DependencyError') + + def test_sort_service_dicts_self_imports(self): + services = [ + { + 'links': ['web'], + 'name': 'web' + }, + ] + + try: + sort_service_dicts(services) + except DependencyError as e: + self.assertIn('web', e.msg) + else: + self.fail('Should have thrown an DependencyError')