Allow bytes paths in VCS bindings

This commit is contained in:
ZyX 2014-09-14 16:13:25 +04:00
parent 8417fd25e2
commit be7056fd7d
6 changed files with 36 additions and 17 deletions

View File

@ -6,3 +6,13 @@ import os
def realpath(path): def realpath(path):
return os.path.abspath(os.path.realpath(path)) return os.path.abspath(os.path.realpath(path))
def join(*components):
if any((isinstance(p, bytes) for p in components)):
return os.path.join(*[
p if isinstance(p, bytes) else p.encode('ascii')
for p in components
])
else:
return os.path.join(*components)

View File

@ -8,6 +8,7 @@ from threading import Lock
from collections import defaultdict from collections import defaultdict
from powerline.lib.watcher import create_tree_watcher from powerline.lib.watcher import create_tree_watcher
from powerline.lib.unicode import out_u
def generate_directories(path): def generate_directories(path):
@ -75,10 +76,10 @@ def get_branch_name(directory, config_file, get_func, create_watcher):
raise raise
# Config file does not exist (happens for mercurial) # Config file does not exist (happens for mercurial)
if config_file not in branch_name_cache: if config_file not in branch_name_cache:
branch_name_cache[config_file] = get_func(directory, config_file) branch_name_cache[config_file] = out_u(get_func(directory, config_file))
if changed: if changed:
# Config file has changed or was not tracked # Config file has changed or was not tracked
branch_name_cache[config_file] = get_func(directory, config_file) branch_name_cache[config_file] = out_u(get_func(directory, config_file))
return branch_name_cache[config_file] return branch_name_cache[config_file]
@ -218,9 +219,15 @@ vcs_props = (
) )
vcs_props_bytes = [
(vcs, vcs_dir.encode('ascii'), check)
for vcs, vcs_dir, check in vcs_props
]
def guess(path, create_watcher): def guess(path, create_watcher):
for directory in generate_directories(path): for directory in generate_directories(path):
for vcs, vcs_dir, check in vcs_props: for vcs, vcs_dir, check in (vcs_props_bytes if isinstance(path, bytes) else vcs_props):
repo_dir = os.path.join(directory, vcs_dir) repo_dir = os.path.join(directory, vcs_dir)
if check(repo_dir): if check(repo_dir):
if os.path.isdir(repo_dir) and not os.access(repo_dir, os.X_OK): if os.path.isdir(repo_dir) and not os.access(repo_dir, os.X_OK):

View File

@ -1,7 +1,6 @@
# vim:fileencoding=utf-8:noet # vim:fileencoding=utf-8:noet
from __future__ import (unicode_literals, division, absolute_import, print_function) from __future__ import (unicode_literals, division, absolute_import, print_function)
import sys
import os import os
import re import re
@ -11,6 +10,7 @@ from locale import getpreferredencoding
from bzrlib import (workingtree, status, library_state, trace, ui) from bzrlib import (workingtree, status, library_state, trace, ui)
from powerline.lib.vcs import get_branch_name, get_file_status from powerline.lib.vcs import get_branch_name, get_file_status
from powerline.lib.path import join
class CoerceIO(StringIO): class CoerceIO(StringIO):
@ -42,8 +42,6 @@ state = None
class Repository(object): class Repository(object):
def __init__(self, directory, create_watcher): def __init__(self, directory, create_watcher):
if isinstance(directory, bytes):
directory = directory.decode(sys.getfilesystemencoding() or sys.getdefaultencoding() or 'utf-8')
self.directory = os.path.abspath(directory) self.directory = os.path.abspath(directory)
self.create_watcher = create_watcher self.create_watcher = create_watcher
@ -62,7 +60,7 @@ class Repository(object):
if path is not None: if path is not None:
return get_file_status( return get_file_status(
directory=self.directory, directory=self.directory,
dirstate_file=os.path.join(self.directory, '.bzr', 'checkout', 'dirstate'), dirstate_file=join(self.directory, '.bzr', 'checkout', 'dirstate'),
file_path=path, file_path=path,
ignore_file_name='.bzrignore', ignore_file_name='.bzrignore',
get_func=self.do_status, get_func=self.do_status,
@ -101,7 +99,7 @@ class Repository(object):
return ans if ans.strip() else None return ans if ans.strip() else None
def branch(self): def branch(self):
config_file = os.path.join(self.directory, '.bzr', 'branch', 'branch.conf') config_file = join(self.directory, '.bzr', 'branch', 'branch.conf')
return get_branch_name( return get_branch_name(
directory=self.directory, directory=self.directory,
config_file=config_file, config_file=config_file,

View File

@ -9,6 +9,7 @@ from locale import getpreferredencoding
from powerline.lib.vcs import get_branch_name, get_file_status from powerline.lib.vcs import get_branch_name, get_file_status
from powerline.lib.shell import readlines from powerline.lib.shell import readlines
from powerline.lib.path import join
_ref_pat = re.compile(br'ref:\s*refs/heads/(.+)') _ref_pat = re.compile(br'ref:\s*refs/heads/(.+)')
@ -27,15 +28,17 @@ def branch_name_from_config_file(directory, config_file):
def git_directory(directory): def git_directory(directory):
path = os.path.join(directory, '.git') path = join(directory, '.git')
if os.path.isfile(path): if os.path.isfile(path):
with open(path, 'rb') as f: with open(path, 'rb') as f:
raw = f.read() raw = f.read()
if not raw.startswith(b'gitdir: '): if not raw.startswith(b'gitdir: '):
raise IOError('invalid gitfile format') raise IOError('invalid gitfile format')
raw = raw[8:].decode(sys.getfilesystemencoding() or 'utf-8') raw = raw[8:]
if raw[-1] == '\n': if raw[-1:] == b'\n':
raw = raw[:-1] raw = raw[:-1]
if not isinstance(path, bytes):
raw = raw.decode(sys.getfilesystemencoding() or 'utf-8')
if not raw: if not raw:
raise IOError('no path in gitfile') raise IOError('no path in gitfile')
return os.path.abspath(os.path.join(directory, raw)) return os.path.abspath(os.path.join(directory, raw))
@ -71,18 +74,18 @@ class GitRepository(object):
# for some reason I cannot be bothered to figure out. # for some reason I cannot be bothered to figure out.
return get_file_status( return get_file_status(
directory=self.directory, directory=self.directory,
dirstate_file=os.path.join(gitd, 'index'), dirstate_file=join(gitd, 'index'),
file_path=path, file_path=path,
ignore_file_name='.gitignore', ignore_file_name='.gitignore',
get_func=self.do_status, get_func=self.do_status,
create_watcher=self.create_watcher, create_watcher=self.create_watcher,
extra_ignore_files=tuple(os.path.join(gitd, x) for x in ('logs/HEAD', 'info/exclude')), extra_ignore_files=tuple(join(gitd, x) for x in ('logs/HEAD', 'info/exclude')),
) )
return self.do_status(self.directory, path) return self.do_status(self.directory, path)
def branch(self): def branch(self):
directory = git_directory(self.directory) directory = git_directory(self.directory)
head = os.path.join(directory, 'HEAD') head = join(directory, 'HEAD')
return get_branch_name( return get_branch_name(
directory=directory, directory=directory,
config_file=head, config_file=head,

View File

@ -8,6 +8,7 @@ from locale import getpreferredencoding
from mercurial import hg, ui, match from mercurial import hg, ui, match
from powerline.lib.vcs import get_branch_name, get_file_status from powerline.lib.vcs import get_branch_name, get_file_status
from powerline.lib.path import join
def branch_name_from_config_file(directory, config_file): def branch_name_from_config_file(directory, config_file):
@ -52,7 +53,7 @@ class Repository(object):
if path: if path:
return get_file_status( return get_file_status(
directory=self.directory, directory=self.directory,
dirstate_file=os.path.join(self.directory, '.hg', 'dirstate'), dirstate_file=join(self.directory, '.hg', 'dirstate'),
file_path=path, file_path=path,
ignore_file_name='.hgignore', ignore_file_name='.hgignore',
get_func=self.do_status, get_func=self.do_status,
@ -77,7 +78,7 @@ class Repository(object):
return self.repo_statuses_str[resulting_status] return self.repo_statuses_str[resulting_status]
def branch(self): def branch(self):
config_file = os.path.join(self.directory, '.hg', 'branch') config_file = join(self.directory, '.hg', 'branch')
return get_branch_name( return get_branch_name(
directory=self.directory, directory=self.directory,
config_file=config_file, config_file=config_file,

View File

@ -445,7 +445,7 @@ class TestVCS(TestCase):
call(['git', 'checkout', '-q', 'branch2'], cwd=GIT_REPO) call(['git', 'checkout', '-q', 'branch2'], cwd=GIT_REPO)
self.do_branch_rename_test(repo, 'branch2') self.do_branch_rename_test(repo, 'branch2')
call(['git', 'checkout', '-q', '--detach', 'branch1'], cwd=GIT_REPO) call(['git', 'checkout', '-q', '--detach', 'branch1'], cwd=GIT_REPO)
self.do_branch_rename_test(repo, lambda b: re.match(br'^[a-f0-9]+$', b)) self.do_branch_rename_test(repo, lambda b: re.match(r'^[a-f0-9]+$', b))
finally: finally:
call(['git', 'checkout', '-q', 'master'], cwd=GIT_REPO) call(['git', 'checkout', '-q', 'master'], cwd=GIT_REPO)