Environment class cleanup

Signed-off-by: Joffrey F <joffrey@docker.com>
This commit is contained in:
Joffrey F 2016-03-08 16:54:14 -08:00
parent fd020ed2cf
commit 1801f83bb8
7 changed files with 58 additions and 66 deletions

View File

@ -34,7 +34,7 @@ def get_config_path_from_options(base_dir, options):
if file_option: if file_option:
return file_option return file_option
environment = config.environment.Environment(base_dir) environment = config.environment.Environment.from_env_file(base_dir)
config_files = environment.get('COMPOSE_FILE') config_files = environment.get('COMPOSE_FILE')
if config_files: if config_files:
return config_files.split(os.pathsep) return config_files.split(os.pathsep)
@ -58,7 +58,7 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False,
config_details = config.find(project_dir, config_path) config_details = config.find(project_dir, config_path)
project_name = get_project_name(config_details.working_dir, project_name) project_name = get_project_name(config_details.working_dir, project_name)
config_data = config.load(config_details) config_data = config.load(config_details)
environment = config.environment.Environment(project_dir) environment = config.environment.Environment.from_env_file(project_dir)
api_version = environment.get( api_version = environment.get(
'COMPOSE_API_VERSION', 'COMPOSE_API_VERSION',
@ -75,7 +75,7 @@ def get_project_name(working_dir, project_name=None):
def normalize_name(name): def normalize_name(name):
return re.sub(r'[^a-z0-9]', '', name.lower()) return re.sub(r'[^a-z0-9]', '', name.lower())
environment = config.environment.Environment(working_dir) environment = config.environment.Environment.from_env_file(working_dir)
project_name = project_name or environment.get('COMPOSE_PROJECT_NAME') project_name = project_name or environment.get('COMPOSE_PROJECT_NAME')
if project_name: if project_name:
return normalize_name(project_name) return normalize_name(project_name)

View File

@ -1,7 +1,6 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
import codecs
import functools import functools
import logging import logging
import operator import operator
@ -17,7 +16,9 @@ from cached_property import cached_property
from ..const import COMPOSEFILE_V1 as V1 from ..const import COMPOSEFILE_V1 as V1
from ..const import COMPOSEFILE_V2_0 as V2_0 from ..const import COMPOSEFILE_V2_0 as V2_0
from ..utils import build_string_dict from ..utils import build_string_dict
from .environment import env_vars_from_file
from .environment import Environment from .environment import Environment
from .environment import split_env
from .errors import CircularReference from .errors import CircularReference
from .errors import ComposeFileNotFound from .errors import ComposeFileNotFound
from .errors import ConfigurationError from .errors import ConfigurationError
@ -129,7 +130,7 @@ class ConfigDetails(namedtuple('_ConfigDetails', 'working_dir config_files envir
cls, cls,
working_dir, working_dir,
config_files, config_files,
Environment(working_dir), Environment.from_env_file(working_dir),
) )
@ -314,9 +315,7 @@ def load(config_details):
networks = load_mapping( networks = load_mapping(
config_details.config_files, 'get_networks', 'Network' config_details.config_files, 'get_networks', 'Network'
) )
service_dicts = load_services( service_dicts = load_services(config_details, main_file)
config_details, main_file,
)
if main_file.version != V1: if main_file.version != V1:
for service_dict in service_dicts: for service_dict in service_dicts:
@ -455,7 +454,7 @@ class ServiceExtendsResolver(object):
self.working_dir = service_config.working_dir self.working_dir = service_config.working_dir
self.already_seen = already_seen or [] self.already_seen = already_seen or []
self.config_file = config_file self.config_file = config_file
self.environment = environment or Environment(None) self.environment = environment or Environment()
@property @property
def signature(self): def signature(self):
@ -802,15 +801,6 @@ def merge_environment(base, override):
return env return env
def split_env(env):
if isinstance(env, six.binary_type):
env = env.decode('utf-8', 'replace')
if '=' in env:
return env.split('=', 1)
else:
return env, None
def split_label(label): def split_label(label):
if '=' in label: if '=' in label:
return label.split('=', 1) return label.split('=', 1)
@ -857,21 +847,6 @@ def resolve_env_var(key, val, environment):
return key, None return key, None
def env_vars_from_file(filename):
"""
Read in a line delimited file of environment variables.
"""
if not os.path.exists(filename):
raise ConfigurationError("Couldn't find env file: %s" % filename)
env = {}
for line in codecs.open(filename, 'r', 'utf-8'):
line = line.strip()
if line and not line.startswith('#'):
k, v = split_env(line)
env[k] = v
return env
def resolve_volume_paths(working_dir, service_dict): def resolve_volume_paths(working_dir, service_dict):
return [ return [
resolve_volume_path(working_dir, volume) resolve_volume_path(working_dir, volume)

View File

@ -1,22 +1,62 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import unicode_literals from __future__ import unicode_literals
import codecs
import logging import logging
import os import os
import six
from .errors import ConfigurationError from .errors import ConfigurationError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class BlankDefaultDict(dict): def split_env(env):
if isinstance(env, six.binary_type):
env = env.decode('utf-8', 'replace')
if '=' in env:
return env.split('=', 1)
else:
return env, None
def env_vars_from_file(filename):
"""
Read in a line delimited file of environment variables.
"""
if not os.path.exists(filename):
raise ConfigurationError("Couldn't find env file: %s" % filename)
env = {}
for line in codecs.open(filename, 'r', 'utf-8'):
line = line.strip()
if line and not line.startswith('#'):
k, v = split_env(line)
env[k] = v
return env
class Environment(dict):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(BlankDefaultDict, self).__init__(*args, **kwargs) super(Environment, self).__init__(*args, **kwargs)
self.missing_keys = [] self.missing_keys = []
self.update(os.environ)
@classmethod
def from_env_file(cls, base_dir):
result = cls()
if base_dir is None:
return result
env_file_path = os.path.join(base_dir, '.env')
try:
result.update(env_vars_from_file(env_file_path))
except ConfigurationError:
pass
return result
def __getitem__(self, key): def __getitem__(self, key):
try: try:
return super(BlankDefaultDict, self).__getitem__(key) return super(Environment, self).__getitem__(key)
except KeyError: except KeyError:
if key not in self.missing_keys: if key not in self.missing_keys:
log.warn( log.warn(
@ -26,26 +66,3 @@ class BlankDefaultDict(dict):
self.missing_keys.append(key) self.missing_keys.append(key)
return "" return ""
class Environment(BlankDefaultDict):
def __init__(self, base_dir):
super(Environment, self).__init__()
if base_dir:
self.load_environment_file(os.path.join(base_dir, '.env'))
self.update(os.environ)
def load_environment_file(self, path):
if not os.path.exists(path):
return
mapping = {}
with open(path, 'r') as f:
for line in f.readlines():
line = line.strip()
if '=' not in line:
raise ConfigurationError(
'Invalid environment variable mapping in env file. '
'Missing "=" in "{0}"'.format(line)
)
mapping.__setitem__(*line.split('=', 1))
self.update(mapping)

View File

@ -90,7 +90,7 @@ class DockerClientTestCase(unittest.TestCase):
if 'command' not in kwargs: if 'command' not in kwargs:
kwargs['command'] = ["top"] kwargs['command'] = ["top"]
kwargs['environment'] = resolve_environment(kwargs, Environment(None)) kwargs['environment'] = resolve_environment(kwargs, Environment())
labels = dict(kwargs.setdefault('labels', {})) labels = dict(kwargs.setdefault('labels', {}))
labels['com.docker.compose.test-name'] = self.id() labels['com.docker.compose.test-name'] = self.id()

View File

@ -2042,7 +2042,7 @@ class EnvTest(unittest.TestCase):
}, },
} }
self.assertEqual( self.assertEqual(
resolve_environment(service_dict, Environment(None)), resolve_environment(service_dict, Environment()),
{'FILE_DEF': 'F1', 'FILE_DEF_EMPTY': '', 'ENV_DEF': 'E3', 'NO_DEF': None}, {'FILE_DEF': 'F1', 'FILE_DEF_EMPTY': '', 'ENV_DEF': 'E3', 'NO_DEF': None},
) )
@ -2080,7 +2080,7 @@ class EnvTest(unittest.TestCase):
os.environ['ENV_DEF'] = 'E3' os.environ['ENV_DEF'] = 'E3'
self.assertEqual( self.assertEqual(
resolve_environment( resolve_environment(
{'env_file': ['tests/fixtures/env/resolve.env']}, Environment(None) {'env_file': ['tests/fixtures/env/resolve.env']}, Environment()
), ),
{ {
'FILE_DEF': u'bär', 'FILE_DEF': u'bär',
@ -2104,7 +2104,7 @@ class EnvTest(unittest.TestCase):
} }
} }
self.assertEqual( self.assertEqual(
resolve_build_args(build, Environment(build['context'])), resolve_build_args(build, Environment.from_env_file(build['context'])),
{'arg1': 'value1', 'empty_arg': '', 'env_arg': 'value2', 'no_env': None}, {'arg1': 'value1', 'empty_arg': '', 'env_arg': 'value2', 'no_env': None},
) )

View File

@ -44,7 +44,7 @@ def test_interpolate_environment_variables_in_services(mock_env):
} }
} }
assert interpolate_environment_variables( assert interpolate_environment_variables(
services, 'service', Environment(None) services, 'service', Environment()
) == expected ) == expected
@ -70,5 +70,5 @@ def test_interpolate_environment_variables_in_volumes(mock_env):
'other': {}, 'other': {},
} }
assert interpolate_environment_variables( assert interpolate_environment_variables(
volumes, 'volume', Environment(None) volumes, 'volume', Environment()
) == expected ) == expected

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import unittest import unittest
from compose.config.environment import BlankDefaultDict as bddict from compose.config.environment import Environment as bddict
from compose.config.interpolation import interpolate from compose.config.interpolation import interpolate
from compose.config.interpolation import InvalidInterpolation from compose.config.interpolation import InvalidInterpolation