diff --git a/powerline/segments/common/env.py b/powerline/segments/common/env.py index 8041e500..1aa991af 100644 --- a/powerline/segments/common/env.py +++ b/powerline/segments/common/env.py @@ -122,15 +122,20 @@ try: # psutil-2.0.0: psutil.Process.username is unbound method if callable(psutil.Process.username): - def _get_user(segment_info): + def _get_user(): return psutil.Process(os.getpid()).username() # pre psutil-2.0.0: psutil.Process.username has type property else: - def _get_user(segment_info): + def _get_user(): return psutil.Process(os.getpid()).username except ImportError: - def _get_user(segment_info): - return segment_info['environ'].get('USER', None) + try: + import pwd + except ImportError: + from getpass import getuser as _get_user + else: + def _get_user(): + return pwd.getpwuid(os.geteuid()).pw_name username = False @@ -138,7 +143,7 @@ username = False _geteuid = getattr(os, 'geteuid', lambda: 1) -def user(pl, segment_info=None, hide_user=None): +def user(pl, hide_user=None): '''Return the current user. :param str hide_user: @@ -150,7 +155,7 @@ def user(pl, segment_info=None, hide_user=None): ''' global username if username is False: - username = _get_user(segment_info) + username = _get_user() if username is None: pl.warn('Failed to get username') return None diff --git a/tests/test_segments.py b/tests/test_segments.py index 3ee3011b..57f23abb 100644 --- a/tests/test_segments.py +++ b/tests/test_segments.py @@ -5,6 +5,7 @@ import sys import os from functools import partial +from collections import namedtuple from powerline.segments import shell, tmux, common from powerline.lib.vcs import get_fallback_create_watcher @@ -451,24 +452,27 @@ class TestEnv(TestCommon): if hasattr(self.module, 'psutil') and not callable(self.module.psutil.Process.username): username = property(username) + struct_passwd = namedtuple('struct_passwd', ('pw_name',)) new_psutil = new_module('psutil', Process=Process) + new_pwd = new_module('pwd', getpwuid=lambda uid: struct_passwd(pw_name='def')) + new_getpass = new_module('getpass', getuser=lambda: 'def') pl = Pl() - with replace_env('USER', 'def') as segment_info: - common.username = False - with replace_attr(self.module, 'os', new_os): - with replace_attr(self.module, 'psutil', new_psutil): - with replace_attr(self.module, '_geteuid', lambda: 5): - self.assertEqual(common.user(pl=pl, segment_info=segment_info), [ - {'contents': 'def', 'highlight_group': ['user']} - ]) - self.assertEqual(common.user(pl=pl, segment_info=segment_info, hide_user='abc'), [ - {'contents': 'def', 'highlight_group': ['user']} - ]) - self.assertEqual(common.user(pl=pl, segment_info=segment_info, hide_user='def'), None) - with replace_attr(self.module, '_geteuid', lambda: 0): - self.assertEqual(common.user(pl=pl, segment_info=segment_info), [ - {'contents': 'def', 'highlight_group': ['superuser', 'user']} - ]) + with replace_attr(self.module, 'pwd', new_pwd): + with replace_attr(self.module, 'getpass', new_getpass): + with replace_attr(self.module, 'os', new_os): + with replace_attr(self.module, 'psutil', new_psutil): + with replace_attr(self.module, '_geteuid', lambda: 5): + self.assertEqual(common.user(pl=pl), [ + {'contents': 'def', 'highlight_group': ['user']} + ]) + self.assertEqual(common.user(pl=pl, hide_user='abc'), [ + {'contents': 'def', 'highlight_group': ['user']} + ]) + self.assertEqual(common.user(pl=pl, hide_user='def'), None) + with replace_attr(self.module, '_geteuid', lambda: 0): + self.assertEqual(common.user(pl=pl), [ + {'contents': 'def', 'highlight_group': ['superuser', 'user']} + ]) def test_cwd(self): new_os = new_module('os', path=os.path, sep='/')