Implement environment singleton to be accessed throughout the code

Load and parse environment file from working dir

Signed-off-by: Joffrey F <joffrey@docker.com>
This commit is contained in:
Joffrey F 2016-03-01 15:49:42 -08:00
parent d1d8df7f72
commit c69d8a3bd2
13 changed files with 151 additions and 64 deletions

View File

@ -21,7 +21,7 @@ log = logging.getLogger(__name__)
def project_from_options(project_dir, options): def project_from_options(project_dir, options):
return get_project( return get_project(
project_dir, project_dir,
get_config_path_from_options(options), get_config_path_from_options(project_dir, options),
project_name=options.get('--project-name'), project_name=options.get('--project-name'),
verbose=options.get('--verbose'), verbose=options.get('--verbose'),
host=options.get('--host'), host=options.get('--host'),
@ -29,12 +29,13 @@ def project_from_options(project_dir, options):
) )
def get_config_path_from_options(options): def get_config_path_from_options(base_dir, options):
file_option = options.get('--file') file_option = options.get('--file')
if file_option: if file_option:
return file_option return file_option
config_files = os.environ.get('COMPOSE_FILE') environment = config.environment.get_instance(base_dir)
config_files = environment.get('COMPOSE_FILE')
if config_files: if config_files:
return config_files.split(os.pathsep) return config_files.split(os.pathsep)
return None return None
@ -57,8 +58,9 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False,
config_details = config.find(project_dir, config_path) config_details = config.find(project_dir, config_path)
project_name = get_project_name(config_details.working_dir, project_name) project_name = get_project_name(config_details.working_dir, project_name)
config_data = config.load(config_details) config_data = config.load(config_details)
environment = config.environment.get_instance(project_dir)
api_version = os.environ.get( api_version = environment.get(
'COMPOSE_API_VERSION', 'COMPOSE_API_VERSION',
API_VERSIONS[config_data.version]) API_VERSIONS[config_data.version])
client = get_client( client = get_client(
@ -73,7 +75,8 @@ def get_project_name(working_dir, project_name=None):
def normalize_name(name): def normalize_name(name):
return re.sub(r'[^a-z0-9]', '', name.lower()) return re.sub(r'[^a-z0-9]', '', name.lower())
project_name = project_name or os.environ.get('COMPOSE_PROJECT_NAME') environment = config.environment.get_instance(working_dir)
project_name = project_name or environment.get('COMPOSE_PROJECT_NAME')
if project_name: if project_name:
return normalize_name(project_name) return normalize_name(project_name)

View File

@ -222,7 +222,7 @@ class TopLevelCommand(object):
--services Print the service names, one per line. --services Print the service names, one per line.
""" """
config_path = get_config_path_from_options(config_options) config_path = get_config_path_from_options(self.project_dir, config_options)
compose_config = config.load(config.find(self.project_dir, config_path)) compose_config = config.load(config.find(self.project_dir, config_path))
if options['--quiet']: if options['--quiet']:

View File

@ -2,6 +2,7 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
from . import environment
from .config import ConfigurationError from .config import ConfigurationError
from .config import DOCKER_CONFIG_KEYS from .config import DOCKER_CONFIG_KEYS
from .config import find from .config import find

View File

@ -17,6 +17,7 @@ from cached_property import cached_property
from ..const import COMPOSEFILE_V1 as V1 from ..const import COMPOSEFILE_V1 as V1
from ..const import COMPOSEFILE_V2_0 as V2_0 from ..const import COMPOSEFILE_V2_0 as V2_0
from ..utils import build_string_dict from ..utils import build_string_dict
from .environment import Environment
from .errors import CircularReference from .errors import CircularReference
from .errors import ComposeFileNotFound from .errors import ComposeFileNotFound
from .errors import ConfigurationError from .errors import ConfigurationError
@ -211,7 +212,8 @@ def find(base_dir, filenames):
if filenames == ['-']: if filenames == ['-']:
return ConfigDetails( return ConfigDetails(
os.getcwd(), os.getcwd(),
[ConfigFile(None, yaml.safe_load(sys.stdin))]) [ConfigFile(None, yaml.safe_load(sys.stdin))],
)
if filenames: if filenames:
filenames = [os.path.join(base_dir, f) for f in filenames] filenames = [os.path.join(base_dir, f) for f in filenames]
@ -221,7 +223,8 @@ def find(base_dir, filenames):
log.debug("Using configuration files: {}".format(",".join(filenames))) log.debug("Using configuration files: {}".format(",".join(filenames)))
return ConfigDetails( return ConfigDetails(
os.path.dirname(filenames[0]), os.path.dirname(filenames[0]),
[ConfigFile.from_filename(f) for f in filenames]) [ConfigFile.from_filename(f) for f in filenames],
)
def validate_config_version(config_files): def validate_config_version(config_files):
@ -288,6 +291,10 @@ def load(config_details):
""" """
validate_config_version(config_details.config_files) validate_config_version(config_details.config_files)
# load environment in working dir for later use in interpolation
# it is done here to avoid having to pass down working_dir
Environment.get_instance(config_details.working_dir)
processed_files = [ processed_files = [
process_config_file(config_file) process_config_file(config_file)
for config_file in config_details.config_files for config_file in config_details.config_files
@ -302,9 +309,8 @@ def load(config_details):
config_details.config_files, 'get_networks', 'Network' config_details.config_files, 'get_networks', 'Network'
) )
service_dicts = load_services( service_dicts = load_services(
config_details.working_dir, config_details, main_file,
main_file, )
[file.get_service_dicts() for file in config_details.config_files])
if main_file.version != V1: if main_file.version != V1:
for service_dict in service_dicts: for service_dict in service_dicts:
@ -348,14 +354,16 @@ def load_mapping(config_files, get_func, entity_type):
return mapping return mapping
def load_services(working_dir, config_file, service_configs): def load_services(config_details, config_file):
def build_service(service_name, service_dict, service_names): def build_service(service_name, service_dict, service_names):
service_config = ServiceConfig.with_abs_paths( service_config = ServiceConfig.with_abs_paths(
working_dir, config_details.working_dir,
config_file.filename, config_file.filename,
service_name, service_name,
service_dict) service_dict)
resolver = ServiceExtendsResolver(service_config, config_file) resolver = ServiceExtendsResolver(
service_config, config_file
)
service_dict = process_service(resolver.run()) service_dict = process_service(resolver.run())
service_config = service_config._replace(config=service_dict) service_config = service_config._replace(config=service_dict)
@ -383,6 +391,10 @@ def load_services(working_dir, config_file, service_configs):
for name in all_service_names for name in all_service_names
} }
service_configs = [
file.get_service_dicts() for file in config_details.config_files
]
service_config = service_configs[0] service_config = service_configs[0]
for next_config in service_configs[1:]: for next_config in service_configs[1:]:
service_config = merge_services(service_config, next_config) service_config = merge_services(service_config, next_config)
@ -462,8 +474,8 @@ class ServiceExtendsResolver(object):
extends_file = ConfigFile.from_filename(config_path) extends_file = ConfigFile.from_filename(config_path)
validate_config_version([self.config_file, extends_file]) validate_config_version([self.config_file, extends_file])
extended_file = process_config_file( extended_file = process_config_file(
extends_file, extends_file, service_name=service_name
service_name=service_name) )
service_config = extended_file.get_service(service_name) service_config = extended_file.get_service(service_name)
return config_path, service_config, service_name return config_path, service_config, service_name
@ -476,7 +488,8 @@ class ServiceExtendsResolver(object):
service_name, service_name,
service_dict), service_dict),
self.config_file, self.config_file,
already_seen=self.already_seen + [self.signature]) already_seen=self.already_seen + [self.signature],
)
service_config = resolver.run() service_config = resolver.run()
other_service_dict = process_service(service_config) other_service_dict = process_service(service_config)
@ -824,10 +837,11 @@ def parse_ulimits(ulimits):
def resolve_env_var(key, val): def resolve_env_var(key, val):
environment = Environment.get_instance()
if val is not None: if val is not None:
return key, val return key, val
elif key in os.environ: elif key in environment:
return key, os.environ[key] return key, environment[key]
else: else:
return key, None return key, None

View File

@ -0,0 +1,69 @@
from __future__ import absolute_import
from __future__ import unicode_literals
import logging
import os
from .errors import ConfigurationError
log = logging.getLogger(__name__)
class BlankDefaultDict(dict):
def __init__(self, *args, **kwargs):
super(BlankDefaultDict, self).__init__(*args, **kwargs)
self.missing_keys = []
def __getitem__(self, key):
try:
return super(BlankDefaultDict, self).__getitem__(key)
except KeyError:
if key not in self.missing_keys:
log.warn(
"The {} variable is not set. Defaulting to a blank string."
.format(key)
)
self.missing_keys.append(key)
return ""
class Environment(BlankDefaultDict):
__instance = None
@classmethod
def get_instance(cls, base_dir='.'):
if cls.__instance:
return cls.__instance
instance = cls(base_dir)
cls.__instance = instance
return instance
@classmethod
def reset(cls):
cls.__instance = None
def __init__(self, base_dir):
super(Environment, self).__init__()
self.load_environment_file(os.path.join(base_dir, '.env'))
self.update(os.environ)
def load_environment_file(self, path):
if not os.path.exists(path):
return
mapping = {}
with open(path, 'r') as f:
for line in f.readlines():
line = line.strip()
if '=' not in line:
raise ConfigurationError(
'Invalid environment variable mapping in env file. '
'Missing "=" in "{0}"'.format(line)
)
mapping.__setitem__(*line.split('=', 1))
self.update(mapping)
def get_instance(base_dir=None):
return Environment.get_instance(base_dir)

View File

@ -2,17 +2,17 @@ from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
import logging import logging
import os
from string import Template from string import Template
import six import six
from .environment import Environment
from .errors import ConfigurationError from .errors import ConfigurationError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def interpolate_environment_variables(config, section): def interpolate_environment_variables(config, section):
mapping = BlankDefaultDict(os.environ) mapping = Environment.get_instance()
def process_item(name, config_dict): def process_item(name, config_dict):
return dict( return dict(
@ -60,25 +60,6 @@ def interpolate(string, mapping):
raise InvalidInterpolation(string) raise InvalidInterpolation(string)
class BlankDefaultDict(dict):
def __init__(self, *args, **kwargs):
super(BlankDefaultDict, self).__init__(*args, **kwargs)
self.missing_keys = []
def __getitem__(self, key):
try:
return super(BlankDefaultDict, self).__getitem__(key)
except KeyError:
if key not in self.missing_keys:
log.warn(
"The {} variable is not set. Defaulting to a blank string."
.format(key)
)
self.missing_keys.append(key)
return ""
class InvalidInterpolation(Exception): class InvalidInterpolation(Exception):
def __init__(self, string): def __init__(self, string):
self.string = string self.string = string

View File

@ -15,7 +15,7 @@ from operator import attrgetter
import yaml import yaml
from docker import errors from docker import errors
from .. import mock from ..helpers import clear_environment
from compose.cli.command import get_project from compose.cli.command import get_project
from compose.container import Container from compose.container import Container
from compose.project import OneOffFilter from compose.project import OneOffFilter
@ -1452,7 +1452,7 @@ class CLITestCase(DockerClientTestCase):
self.assertEqual(len(containers), 1) self.assertEqual(len(containers), 1)
self.assertIn("FOO=1", containers[0].get('Config.Env')) self.assertIn("FOO=1", containers[0].get('Config.Env'))
@mock.patch.dict(os.environ) @clear_environment
def test_home_and_env_var_in_volume_path(self): def test_home_and_env_var_in_volume_path(self):
os.environ['VOLUME_NAME'] = 'my-volume' os.environ['VOLUME_NAME'] = 'my-volume'
os.environ['HOME'] = '/tmp/home-dir' os.environ['HOME'] = '/tmp/home-dir'

View File

@ -1,9 +1,14 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
import functools
import os
from . import mock
from compose.config.config import ConfigDetails from compose.config.config import ConfigDetails
from compose.config.config import ConfigFile from compose.config.config import ConfigFile
from compose.config.config import load from compose.config.config import load
from compose.config.environment import Environment
def build_config(contents, **kwargs): def build_config(contents, **kwargs):
@ -14,3 +19,11 @@ def build_config_details(contents, working_dir='working_dir', filename='filename
return ConfigDetails( return ConfigDetails(
working_dir, working_dir,
[ConfigFile(filename, contents)]) [ConfigFile(filename, contents)])
def clear_environment(f):
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
Environment.reset()
with mock.patch.dict(os.environ):
f(self, *args, **kwargs)

View File

@ -12,6 +12,7 @@ from six import StringIO
from six import text_type from six import text_type
from .. import mock from .. import mock
from ..helpers import clear_environment
from .testcases import DockerClientTestCase from .testcases import DockerClientTestCase
from .testcases import get_links from .testcases import get_links
from .testcases import pull_busybox from .testcases import pull_busybox
@ -912,7 +913,7 @@ class ServiceTest(DockerClientTestCase):
}.items(): }.items():
self.assertEqual(env[k], v) self.assertEqual(env[k], v)
@mock.patch.dict(os.environ) @clear_environment
def test_resolve_env(self): def test_resolve_env(self):
os.environ['FILE_DEF'] = 'E1' os.environ['FILE_DEF'] = 'E1'
os.environ['FILE_DEF_EMPTY'] = 'E2' os.environ['FILE_DEF_EMPTY'] = 'E2'

View File

@ -11,6 +11,7 @@ import pytest
from .. import mock from .. import mock
from .. import unittest from .. import unittest
from ..helpers import build_config from ..helpers import build_config
from ..helpers import clear_environment
from compose.cli.command import get_project from compose.cli.command import get_project
from compose.cli.command import get_project_name from compose.cli.command import get_project_name
from compose.cli.docopt_command import NoSuchCommand from compose.cli.docopt_command import NoSuchCommand
@ -43,11 +44,11 @@ class CLITestCase(unittest.TestCase):
project_name = get_project_name(None, project_name=name) project_name = get_project_name(None, project_name=name)
self.assertEquals('explicitprojectname', project_name) self.assertEquals('explicitprojectname', project_name)
@clear_environment
def test_project_name_from_environment_new_var(self): def test_project_name_from_environment_new_var(self):
name = 'namefromenv' name = 'namefromenv'
with mock.patch.dict(os.environ): os.environ['COMPOSE_PROJECT_NAME'] = name
os.environ['COMPOSE_PROJECT_NAME'] = name project_name = get_project_name(None)
project_name = get_project_name(None)
self.assertEquals(project_name, name) self.assertEquals(project_name, name)
def test_project_name_with_empty_environment_var(self): def test_project_name_with_empty_environment_var(self):

View File

@ -23,6 +23,7 @@ from compose.config.types import VolumeSpec
from compose.const import IS_WINDOWS_PLATFORM from compose.const import IS_WINDOWS_PLATFORM
from tests import mock from tests import mock
from tests import unittest from tests import unittest
from tests.helpers import clear_environment
DEFAULT_VERSION = V2_0 DEFAULT_VERSION = V2_0
@ -1581,7 +1582,7 @@ class PortsTest(unittest.TestCase):
class InterpolationTest(unittest.TestCase): class InterpolationTest(unittest.TestCase):
@mock.patch.dict(os.environ) @clear_environment
def test_config_file_with_environment_variable(self): def test_config_file_with_environment_variable(self):
os.environ.update( os.environ.update(
IMAGE="busybox", IMAGE="busybox",
@ -1604,7 +1605,7 @@ class InterpolationTest(unittest.TestCase):
} }
]) ])
@mock.patch.dict(os.environ) @clear_environment
def test_unset_variable_produces_warning(self): def test_unset_variable_produces_warning(self):
os.environ.pop('FOO', None) os.environ.pop('FOO', None)
os.environ.pop('BAR', None) os.environ.pop('BAR', None)
@ -1628,7 +1629,7 @@ class InterpolationTest(unittest.TestCase):
self.assertIn('BAR', warnings[0]) self.assertIn('BAR', warnings[0])
self.assertIn('FOO', warnings[1]) self.assertIn('FOO', warnings[1])
@mock.patch.dict(os.environ) @clear_environment
def test_invalid_interpolation(self): def test_invalid_interpolation(self):
with self.assertRaises(config.ConfigurationError) as cm: with self.assertRaises(config.ConfigurationError) as cm:
config.load( config.load(
@ -1667,7 +1668,7 @@ class VolumeConfigTest(unittest.TestCase):
d = make_service_dict('foo', {'build': '.', 'volumes': ['/data']}, working_dir='.') d = make_service_dict('foo', {'build': '.', 'volumes': ['/data']}, working_dir='.')
self.assertEqual(d['volumes'], ['/data']) self.assertEqual(d['volumes'], ['/data'])
@mock.patch.dict(os.environ) @clear_environment
def test_volume_binding_with_environment_variable(self): def test_volume_binding_with_environment_variable(self):
os.environ['VOLUME_PATH'] = '/host/path' os.environ['VOLUME_PATH'] = '/host/path'
@ -1681,7 +1682,7 @@ class VolumeConfigTest(unittest.TestCase):
self.assertEqual(d['volumes'], [VolumeSpec.parse('/host/path:/container/path')]) self.assertEqual(d['volumes'], [VolumeSpec.parse('/host/path:/container/path')])
@pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason='posix paths') @pytest.mark.skipif(IS_WINDOWS_PLATFORM, reason='posix paths')
@mock.patch.dict(os.environ) @clear_environment
def test_volume_binding_with_home(self): def test_volume_binding_with_home(self):
os.environ['HOME'] = '/home/user' os.environ['HOME'] = '/home/user'
d = make_service_dict('foo', {'build': '.', 'volumes': ['~:/container/path']}, working_dir='.') d = make_service_dict('foo', {'build': '.', 'volumes': ['~:/container/path']}, working_dir='.')
@ -1739,7 +1740,7 @@ class VolumeConfigTest(unittest.TestCase):
working_dir='c:\\Users\\me\\myproject') working_dir='c:\\Users\\me\\myproject')
self.assertEqual(d['volumes'], ['c:\\Users\\me\\otherproject:/data']) self.assertEqual(d['volumes'], ['c:\\Users\\me\\otherproject:/data'])
@mock.patch.dict(os.environ) @clear_environment
def test_home_directory_with_driver_does_not_expand(self): def test_home_directory_with_driver_does_not_expand(self):
os.environ['NAME'] = 'surprise!' os.environ['NAME'] = 'surprise!'
d = make_service_dict('foo', { d = make_service_dict('foo', {
@ -2025,7 +2026,7 @@ class EnvTest(unittest.TestCase):
def test_parse_environment_empty(self): def test_parse_environment_empty(self):
self.assertEqual(config.parse_environment(None), {}) self.assertEqual(config.parse_environment(None), {})
@mock.patch.dict(os.environ) @clear_environment
def test_resolve_environment(self): def test_resolve_environment(self):
os.environ['FILE_DEF'] = 'E1' os.environ['FILE_DEF'] = 'E1'
os.environ['FILE_DEF_EMPTY'] = 'E2' os.environ['FILE_DEF_EMPTY'] = 'E2'
@ -2072,7 +2073,7 @@ class EnvTest(unittest.TestCase):
assert 'Couldn\'t find env file' in exc.exconly() assert 'Couldn\'t find env file' in exc.exconly()
assert 'nonexistent.env' in exc.exconly() assert 'nonexistent.env' in exc.exconly()
@mock.patch.dict(os.environ) @clear_environment
def test_resolve_environment_from_env_file_with_empty_values(self): def test_resolve_environment_from_env_file_with_empty_values(self):
os.environ['FILE_DEF'] = 'E1' os.environ['FILE_DEF'] = 'E1'
os.environ['FILE_DEF_EMPTY'] = 'E2' os.environ['FILE_DEF_EMPTY'] = 'E2'
@ -2087,7 +2088,7 @@ class EnvTest(unittest.TestCase):
}, },
) )
@mock.patch.dict(os.environ) @clear_environment
def test_resolve_build_args(self): def test_resolve_build_args(self):
os.environ['env_arg'] = 'value2' os.environ['env_arg'] = 'value2'
@ -2106,7 +2107,7 @@ class EnvTest(unittest.TestCase):
) )
@pytest.mark.xfail(IS_WINDOWS_PLATFORM, reason='paths use slash') @pytest.mark.xfail(IS_WINDOWS_PLATFORM, reason='paths use slash')
@mock.patch.dict(os.environ) @clear_environment
def test_resolve_path(self): def test_resolve_path(self):
os.environ['HOSTENV'] = '/tmp' os.environ['HOSTENV'] = '/tmp'
os.environ['CONTAINERENV'] = '/host/tmp' os.environ['CONTAINERENV'] = '/host/tmp'
@ -2393,7 +2394,7 @@ class ExtendsTest(unittest.TestCase):
assert 'net: container' in excinfo.exconly() assert 'net: container' in excinfo.exconly()
assert 'cannot be extended' in excinfo.exconly() assert 'cannot be extended' in excinfo.exconly()
@mock.patch.dict(os.environ) @clear_environment
def test_load_config_runs_interpolation_in_extended_service(self): def test_load_config_runs_interpolation_in_extended_service(self):
os.environ.update(HOSTNAME_VALUE="penguin") os.environ.update(HOSTNAME_VALUE="penguin")
expected_interpolated_value = "host-penguin" expected_interpolated_value = "host-penguin"
@ -2465,6 +2466,7 @@ class ExtendsTest(unittest.TestCase):
}, },
])) ]))
@clear_environment
def test_extends_with_environment_and_env_files(self): def test_extends_with_environment_and_env_files(self):
tmpdir = py.test.ensuretemp('test_extends_with_environment') tmpdir = py.test.ensuretemp('test_extends_with_environment')
self.addCleanup(tmpdir.remove) self.addCleanup(tmpdir.remove)
@ -2520,12 +2522,12 @@ class ExtendsTest(unittest.TestCase):
}, },
}, },
] ]
with mock.patch.dict(os.environ):
os.environ['SECRET'] = 'secret' os.environ['SECRET'] = 'secret'
os.environ['THING'] = 'thing' os.environ['THING'] = 'thing'
os.environ['COMMON_ENV_FILE'] = 'secret' os.environ['COMMON_ENV_FILE'] = 'secret'
os.environ['TOP_ENV_FILE'] = 'secret' os.environ['TOP_ENV_FILE'] = 'secret'
config = load_from_filename(str(tmpdir.join('docker-compose.yml'))) config = load_from_filename(str(tmpdir.join('docker-compose.yml')))
assert config == expected assert config == expected

View File

@ -6,12 +6,14 @@ import os
import mock import mock
import pytest import pytest
from compose.config.environment import Environment
from compose.config.interpolation import interpolate_environment_variables from compose.config.interpolation import interpolate_environment_variables
@pytest.yield_fixture @pytest.yield_fixture
def mock_env(): def mock_env():
with mock.patch.dict(os.environ): with mock.patch.dict(os.environ):
Environment.reset()
os.environ['USER'] = 'jenny' os.environ['USER'] = 'jenny'
os.environ['FOO'] = 'bar' os.environ['FOO'] = 'bar'
yield yield

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import unittest import unittest
from compose.config.interpolation import BlankDefaultDict as bddict from compose.config.environment import BlankDefaultDict as bddict
from compose.config.interpolation import interpolate from compose.config.interpolation import interpolate
from compose.config.interpolation import InvalidInterpolation from compose.config.interpolation import InvalidInterpolation