diff --git a/fig/service.py b/fig/service.py index 0648aeb87..bbbef7bc4 100644 --- a/fig/service.py +++ b/fig/service.py @@ -251,13 +251,7 @@ class Service(object): def start_container(self, container=None, intermediate_container=None, **override_options): container = container or self.create_container(**override_options) options = dict(self.options, **override_options) - ports = {} - for port in options.get('ports') or []: - internal_port, external = split_port(port) - if internal_port in ports: - ports[internal_port].append(external) - else: - ports[internal_port] = [external] + port_bindings = build_port_bindings(options.get('ports') or []) volume_bindings = dict( build_volume_binding(parse_volume_spec(volume)) @@ -270,7 +264,7 @@ class Service(object): container.start( links=self._get_links(link_to_self=options.get('one_off', False)), - port_bindings=ports, + port_bindings=port_bindings, binds=volume_bindings, volumes_from=self._get_volumes_from(intermediate_container), privileged=privileged, @@ -498,6 +492,17 @@ def build_volume_binding(volume_spec): return os.path.abspath(os.path.expandvars(external)), internal +def build_port_bindings(ports): + port_bindings = {} + for port in ports: + internal_port, external = split_port(port) + if internal_port in port_bindings: + port_bindings[internal_port].append(external) + else: + port_bindings[internal_port] = [external] + return port_bindings + + def split_port(port): parts = str(port).split(':') if not 1 <= len(parts) <= 3: diff --git a/tests/unit/service_test.py b/tests/unit/service_test.py index f1d1c79d9..119b41440 100644 --- a/tests/unit/service_test.py +++ b/tests/unit/service_test.py @@ -13,6 +13,7 @@ from fig.container import Container from fig.service import ( ConfigError, split_port, + build_port_bindings, parse_volume_spec, build_volume_binding, APIError, @@ -114,6 +115,19 @@ class ServiceTest(unittest.TestCase): with self.assertRaises(ConfigError): split_port("0.0.0.0:1000:2000:tcp") + def test_build_port_bindings_with_one_port(self): + port_bindings = build_port_bindings(["127.0.0.1:1000:1000"]) + self.assertEqual(port_bindings["1000"],[("127.0.0.1","1000")]) + + def test_build_port_bindings_with_matching_internal_ports(self): + port_bindings = build_port_bindings(["127.0.0.1:1000:1000","127.0.0.1:2000:1000"]) + self.assertEqual(port_bindings["1000"],[("127.0.0.1","1000"),("127.0.0.1","2000")]) + + def test_build_port_bindings_with_nonmatching_internal_ports(self): + port_bindings = build_port_bindings(["127.0.0.1:1000:1000","127.0.0.1:2000:2000"]) + self.assertEqual(port_bindings["1000"],[("127.0.0.1","1000")]) + self.assertEqual(port_bindings["2000"],[("127.0.0.1","2000")]) + def test_split_domainname_none(self): service = Service('foo', hostname='name', client=self.mock_client) self.mock_client.containers.return_value = []