diff --git a/powerline/segments/common.py b/powerline/segments/common.py index de8efd37..d1a30d9b 100644 --- a/powerline/segments/common.py +++ b/powerline/segments/common.py @@ -20,14 +20,18 @@ from collections import namedtuple @requires_segment_info -def hostname(pl, segment_info, only_if_ssh=False): +def hostname(pl, segment_info, only_if_ssh=False, exclude_domain=False): '''Return the current hostname. :param bool only_if_ssh: only return the hostname if currently in an SSH session + :param bool exclude_domain: + return the hostname without domain if there is one ''' if only_if_ssh and not segment_info['environ'].get('SSH_CLIENT'): return None + if exclude_domain: + return socket.gethostname().split('.')[0] return socket.gethostname() diff --git a/tests/test_segments.py b/tests/test_segments.py index 73cf2fa6..40cc7cd0 100644 --- a/tests/test_segments.py +++ b/tests/test_segments.py @@ -42,9 +42,19 @@ class TestCommon(TestCase): with replace_module_module(common, 'socket', gethostname=lambda: 'abc'): self.assertEqual(common.hostname(pl=pl, segment_info=segment_info), 'abc') self.assertEqual(common.hostname(pl=pl, segment_info=segment_info, only_if_ssh=True), 'abc') - segment_info['environ'].pop('SSH_CLIENT') + with replace_module_module(common, 'socket', gethostname=lambda: 'abc.mydomain'): + self.assertEqual(common.hostname(pl=pl, segment_info=segment_info), 'abc.mydomain') + self.assertEqual(common.hostname(pl=pl, segment_info=segment_info, exclude_domain=True), 'abc') + self.assertEqual(common.hostname(pl=pl, segment_info=segment_info, only_if_ssh=True), 'abc.mydomain') + self.assertEqual(common.hostname(pl=pl, segment_info=segment_info, only_if_ssh=True, exclude_domain=True), 'abc') + segment_info['environ'].pop('SSH_CLIENT') + with replace_module_module(common, 'socket', gethostname=lambda: 'abc'): self.assertEqual(common.hostname(pl=pl, segment_info=segment_info), 'abc') self.assertEqual(common.hostname(pl=pl, segment_info=segment_info, only_if_ssh=True), None) + with replace_module_module(common, 'socket', gethostname=lambda: 'abc.mydomain'): + self.assertEqual(common.hostname(pl=pl, segment_info=segment_info), 'abc.mydomain') + self.assertEqual(common.hostname(pl=pl, segment_info=segment_info, exclude_domain=True), 'abc') + self.assertEqual(common.hostname(pl=pl, segment_info=segment_info, only_if_ssh=True, exclude_domain=True), None) def test_user(self): new_os = new_module('os', getpid=lambda: 1)