From 79df2ebe1bbe81232acd84eeca7bf66af8e3004b Mon Sep 17 00:00:00 2001
From: Daniel Nephin <dnephin@docker.com>
Date: Wed, 13 Jan 2016 15:19:02 -0500
Subject: [PATCH] Support variable interpolation for volumes and networks
 sections.

Signed-off-by: Daniel Nephin <dnephin@docker.com>
---
 compose/config/config.py                | 21 +++++---
 compose/config/interpolation.py         | 31 +++++------
 tests/unit/config/config_test.py        | 16 +++---
 tests/unit/config/interpolation_test.py | 69 +++++++++++++++++++++++++
 4 files changed, 104 insertions(+), 33 deletions(-)
 create mode 100644 tests/unit/config/interpolation_test.py

diff --git a/compose/config/config.py b/compose/config/config.py
index c8d93faf6..f6df3d3bf 100644
--- a/compose/config/config.py
+++ b/compose/config/config.py
@@ -138,6 +138,9 @@ class ConfigFile(namedtuple('_ConfigFile', 'filename config')):
     def get_volumes(self):
         return {} if self.version == 1 else self.config.get('volumes', {})
 
+    def get_networks(self):
+        return {} if self.version == 1 else self.config.get('networks', {})
+
 
 class Config(namedtuple('_Config', 'version services volumes networks')):
     """
@@ -258,8 +261,8 @@ def load(config_details):
     config_details = config_details._replace(config_files=processed_files)
 
     main_file = config_details.config_files[0]
-    volumes = load_mapping(config_details.config_files, 'volumes', 'Volume')
-    networks = load_mapping(config_details.config_files, 'networks', 'Network')
+    volumes = load_mapping(config_details.config_files, 'get_volumes', 'Volume')
+    networks = load_mapping(config_details.config_files, 'get_networks', 'Network')
     service_dicts = load_services(
         config_details.working_dir,
         main_file.filename,
@@ -268,11 +271,11 @@ def load(config_details):
     return Config(main_file.version, service_dicts, volumes, networks)
 
 
-def load_mapping(config_files, key, entity_type):
+def load_mapping(config_files, get_func, entity_type):
     mapping = {}
 
     for config_file in config_files:
-        for name, config in config_file.config.get(key, {}).items():
+        for name, config in getattr(config_file, get_func)().items():
             mapping[name] = config or {}
             if not config:
                 continue
@@ -347,12 +350,16 @@ def process_config_file(config_file, service_name=None):
     service_dicts = config_file.get_service_dicts()
     validate_top_level_service_objects(config_file.filename, service_dicts)
 
-    # TODO: interpolate config in volumes/network sections as well
-    interpolated_config = interpolate_environment_variables(service_dicts)
+    interpolated_config = interpolate_environment_variables(service_dicts, 'service')
 
     if config_file.version == 2:
         processed_config = dict(config_file.config)
-        processed_config.update({'services': interpolated_config})
+        processed_config['services'] = interpolated_config
+        processed_config['volumes'] = interpolate_environment_variables(
+            config_file.get_volumes(), 'volume')
+        processed_config['networks'] = interpolate_environment_variables(
+            config_file.get_networks(), 'network')
+
     if config_file.version == 1:
         processed_config = interpolated_config
 
diff --git a/compose/config/interpolation.py b/compose/config/interpolation.py
index 7a7576448..e1c781fec 100644
--- a/compose/config/interpolation.py
+++ b/compose/config/interpolation.py
@@ -11,35 +11,32 @@ from .errors import ConfigurationError
 log = logging.getLogger(__name__)
 
 
-def interpolate_environment_variables(service_dicts):
+def interpolate_environment_variables(config, section):
     mapping = BlankDefaultDict(os.environ)
 
+    def process_item(name, config_dict):
+        return dict(
+            (key, interpolate_value(name, key, val, section, mapping))
+            for key, val in (config_dict or {}).items()
+        )
+
     return dict(
-        (service_name, process_service(service_name, service_dict, mapping))
-        for (service_name, service_dict) in service_dicts.items()
+        (name, process_item(name, config_dict))
+        for name, config_dict in config.items()
     )
 
 
-def process_service(service_name, service_dict, mapping):
-    return dict(
-        (key, interpolate_value(service_name, key, val, mapping))
-        for (key, val) in service_dict.items()
-    )
-
-
-def interpolate_value(service_name, config_key, value, mapping):
+def interpolate_value(name, config_key, value, section, mapping):
     try:
         return recursive_interpolate(value, mapping)
     except InvalidInterpolation as e:
         raise ConfigurationError(
             'Invalid interpolation format for "{config_key}" option '
-            'in service "{service_name}": "{string}"'
-            .format(
+            'in {section} "{name}": "{string}"'.format(
                 config_key=config_key,
-                service_name=service_name,
-                string=e.string,
-            )
-        )
+                name=name,
+                section=section,
+                string=e.string))
 
 
 def recursive_interpolate(obj, mapping):
diff --git a/tests/unit/config/config_test.py b/tests/unit/config/config_test.py
index f0d432d58..f88166432 100644
--- a/tests/unit/config/config_test.py
+++ b/tests/unit/config/config_test.py
@@ -686,8 +686,8 @@ class ConfigTest(unittest.TestCase):
             )
         )
 
-        self.assertTrue(mock_logging.warn.called)
-        self.assertTrue(expected_warning_msg in mock_logging.warn.call_args[0][0])
+        assert mock_logging.warn.called
+        assert expected_warning_msg in mock_logging.warn.call_args[0][0]
 
     def test_config_valid_environment_dict_key_contains_dashes(self):
         services = config.load(
@@ -1664,15 +1664,13 @@ class ExtendsTest(unittest.TestCase):
             load_from_filename('tests/fixtures/extends/invalid-net.yml')
 
     @mock.patch.dict(os.environ)
-    def test_valid_interpolation_in_extended_service(self):
-        os.environ.update(
-            HOSTNAME_VALUE="penguin",
-        )
+    def test_load_config_runs_interpolation_in_extended_service(self):
+        os.environ.update(HOSTNAME_VALUE="penguin")
         expected_interpolated_value = "host-penguin"
-
-        service_dicts = load_from_filename('tests/fixtures/extends/valid-interpolation.yml')
+        service_dicts = load_from_filename(
+            'tests/fixtures/extends/valid-interpolation.yml')
         for service in service_dicts:
-            self.assertTrue(service['hostname'], expected_interpolated_value)
+            assert service['hostname'] == expected_interpolated_value
 
     @pytest.mark.xfail(IS_WINDOWS_PLATFORM, reason='paths use slash')
     def test_volume_path(self):
diff --git a/tests/unit/config/interpolation_test.py b/tests/unit/config/interpolation_test.py
new file mode 100644
index 000000000..0691e8865
--- /dev/null
+++ b/tests/unit/config/interpolation_test.py
@@ -0,0 +1,69 @@
+from __future__ import absolute_import
+from __future__ import unicode_literals
+
+import os
+
+import mock
+import pytest
+
+from compose.config.interpolation import interpolate_environment_variables
+
+
+@pytest.yield_fixture
+def mock_env():
+    with mock.patch.dict(os.environ):
+        os.environ['USER'] = 'jenny'
+        os.environ['FOO'] = 'bar'
+        yield
+
+
+def test_interpolate_environment_variables_in_services(mock_env):
+    services = {
+        'servivea': {
+            'image': 'example:${USER}',
+            'volumes': ['$FOO:/target'],
+            'logging': {
+                'driver': '${FOO}',
+                'options': {
+                    'user': '$USER',
+                }
+            }
+        }
+    }
+    expected = {
+        'servivea': {
+            'image': 'example:jenny',
+            'volumes': ['bar:/target'],
+            'logging': {
+                'driver': 'bar',
+                'options': {
+                    'user': 'jenny',
+                }
+            }
+        }
+    }
+    assert interpolate_environment_variables(services, 'service') == expected
+
+
+def test_interpolate_environment_variables_in_volumes(mock_env):
+    volumes = {
+        'data': {
+            'driver': '$FOO',
+            'driver_opts': {
+                'max': 2,
+                'user': '${USER}'
+            }
+        },
+        'other': None,
+    }
+    expected = {
+        'data': {
+            'driver': 'bar',
+            'driver_opts': {
+                'max': 2,
+                'user': 'jenny'
+            }
+        },
+        'other': {},
+    }
+    assert interpolate_environment_variables(volumes, 'volume') == expected