Remove mutable global variables from daemon script

This commit is contained in:
Foo 2016-08-20 20:57:42 +03:00
parent 03e63fc8d2
commit 25657089db
1 changed files with 88 additions and 85 deletions

View File

@ -6,6 +6,9 @@ import socket
import os import os
import errno import errno
import sys import sys
import fcntl
import atexit
import stat
from argparse import ArgumentParser from argparse import ArgumentParser
from select import select from select import select
@ -15,6 +18,7 @@ from functools import partial
from io import BytesIO from io import BytesIO
from threading import Event from threading import Event
from itertools import chain from itertools import chain
from logging import StreamHandler
from powerline.shell import ShellPowerline from powerline.shell import ShellPowerline
from powerline.commands.main import finish_args, write_output from powerline.commands.main import finish_args, write_output
@ -26,12 +30,7 @@ from powerline.commands.main import get_argparser as get_main_argparser
from powerline.commands.daemon import get_argparser as get_daemon_argparser from powerline.commands.daemon import get_argparser as get_daemon_argparser
is_daemon = False USE_FILESYSTEM = not sys.platform.lower().startswith('linux')
use_filesystem = not sys.platform.lower().startswith('linux')
address = None
pidfile = None
ts_shutdown_event = Event()
class NonInteractiveArgParser(ArgumentParser): class NonInteractiveArgParser(ArgumentParser):
@ -48,47 +47,48 @@ class NonInteractiveArgParser(ArgumentParser):
raise Exception(self.format_usage()) raise Exception(self.format_usage())
parser = get_main_argparser(NonInteractiveArgParser)
EOF = b'EOF\0\0' EOF = b'EOF\0\0'
powerlines = {}
logger = None class State(object):
config_loader = None __slots__ = ('powerlines', 'logger', 'config_loader', 'started_wm_threads',
home = os.path.expanduser('~') 'ts_shutdown_event')
started_wm_threads = {}
def __init__(self, **kwargs):
self.logger = None
self.config_loader = None
self.started_wm_threads = {}
self.powerlines = {}
self.ts_shutdown_event = Event()
class PowerlineDaemon(ShellPowerline): HOME = os.path.expanduser('~')
class NonDaemonShellPowerline(ShellPowerline):
def get_log_handler(self): def get_log_handler(self):
if not is_daemon: return StreamHandler()
import logging
return logging.StreamHandler()
return super(PowerlineDaemon, self).get_log_handler()
def start_wm(args, environ, cwd): def start_wm(args, environ, cwd, is_daemon, state):
wm_name = args.ext[0][3:] wm_name = args.ext[0][3:]
if wm_name in started_wm_threads: if wm_name in state.started_wm_threads:
return b'' return b''
thread_shutdown_event = Event() thread_shutdown_event = Event()
thread = wm_threads[wm_name]( thread = wm_threads[wm_name](
thread_shutdown_event=thread_shutdown_event, thread_shutdown_event=thread_shutdown_event,
pl_shutdown_event=ts_shutdown_event, pl_shutdown_event=state.ts_shutdown_event,
pl_config_loader=config_loader, pl_config_loader=state.config_loader,
) )
thread.start() thread.start()
started_wm_threads[wm_name] = (thread, thread_shutdown_event) state.started_wm_threads[wm_name] = (thread, thread_shutdown_event)
return b'' return b''
def render(args, environ, cwd): def render(args, environ, cwd, is_daemon, state):
global logger
global config_loader
segment_info = { segment_info = {
'getcwd': lambda: cwd, 'getcwd': lambda: cwd,
'home': environ.get('HOME', home), 'home': environ.get('HOME', HOME),
'environ': environ, 'environ': environ,
'args': args, 'args': args,
} }
@ -103,22 +103,23 @@ def render(args, environ, cwd):
environ.get('POWERLINE_CONFIG_PATHS', ''), environ.get('POWERLINE_CONFIG_PATHS', ''),
) )
PowerlineClass = ShellPowerline if is_daemon else NonDaemonShellPowerline
powerline = None powerline = None
try: try:
powerline = powerlines[key] powerline = state.powerlines[key]
except KeyError: except KeyError:
try: try:
powerline = powerlines[key] = PowerlineDaemon( powerline = state.powerlines[key] = PowerlineClass(
args, args,
logger=logger, logger=state.logger,
config_loader=config_loader, config_loader=state.config_loader,
run_once=False, run_once=False,
shutdown_event=ts_shutdown_event, shutdown_event=state.ts_shutdown_event,
) )
if logger is None: if state.logger is None:
logger = powerline.logger state.logger = powerline.logger
if config_loader is None: if state.config_loader is None:
config_loader = powerline.config_loader state.config_loader = powerline.config_loader
except SystemExit: except SystemExit:
# Somebody thought raising system exit was a good idea, # Somebody thought raising system exit was a good idea,
return '' return ''
@ -189,7 +190,7 @@ def safe_bytes(o, encoding=get_preferred_output_encoding()):
return safe_bytes(str(e), encoding) return safe_bytes(str(e), encoding)
def parse_args(req, encoding=get_preferred_arguments_encoding()): def parse_args(req, parser, encoding=get_preferred_arguments_encoding()):
args = [x.decode(encoding) for x in req.split(b'\0') if x] args = [x.decode(encoding) for x in req.split(b'\0') if x]
numargs = int(args[0], 16) numargs = int(args[0], 16)
shell_args = parser.parse_args(args[1:numargs + 1]) shell_args = parser.parse_args(args[1:numargs + 1])
@ -199,19 +200,20 @@ def parse_args(req, encoding=get_preferred_arguments_encoding()):
return shell_args, environ, cwd return shell_args, environ, cwd
def get_answer(req): def get_answer(req, is_daemon, argparser, state):
try: try:
args, environ, cwd = parse_args(req) args, environ, cwd = parse_args(req, argparser)
finish_args(parser, environ, args, is_daemon=True) finish_args(argparser, environ, args, is_daemon=True)
if args.ext[0].startswith('wm.'): if args.ext[0].startswith('wm.'):
return safe_bytes(start_wm(args, environ, cwd)) return safe_bytes(start_wm(args, environ, cwd, is_daemon, state))
else: else:
return safe_bytes(render(args, environ, cwd)) return safe_bytes(render(args, environ, cwd, is_daemon, state))
except Exception as e: except Exception as e:
return safe_bytes(str(e)) return safe_bytes(str(e))
def do_one(sock, read_sockets, write_sockets, result_map): def do_one(sock, read_sockets, write_sockets, result_map, is_daemon, argparser,
state):
r, w, e = select( r, w, e = select(
tuple(read_sockets) + (sock,), tuple(read_sockets) + (sock,),
tuple(write_sockets), tuple(write_sockets),
@ -241,7 +243,7 @@ def do_one(sock, read_sockets, write_sockets, result_map):
if req == EOF: if req == EOF:
raise SystemExit(0) raise SystemExit(0)
elif req: elif req:
ans = get_answer(req) ans = get_answer(req, is_daemon, argparser, state)
result_map[s] = ans result_map[s] = ans
write_sockets.add(s) write_sockets.add(s)
else: else:
@ -257,7 +259,7 @@ def do_one(sock, read_sockets, write_sockets, result_map):
s.close() s.close()
def shutdown(sock, read_sockets, write_sockets): def shutdown(sock, read_sockets, write_sockets, state):
'''Perform operations necessary for nicely shutting down daemon '''Perform operations necessary for nicely shutting down daemon
Specifically it Specifically it
@ -277,11 +279,11 @@ def shutdown(sock, read_sockets, write_sockets):
s.close() s.close()
# Notify ThreadedSegments # Notify ThreadedSegments
ts_shutdown_event.set() state.ts_shutdown_event.set()
for thread, shutdown_event in started_wm_threads.values(): for thread, shutdown_event in state.started_wm_threads.values():
shutdown_event.set() shutdown_event.set()
for thread, shutdown_event in started_wm_threads.values(): for thread, shutdown_event in state.started_wm_threads.values():
wait_time = total_wait_time - (monotonic() - shutdown_start_time) wait_time = total_wait_time - (monotonic() - shutdown_start_time)
if wait_time > 0: if wait_time > 0:
thread.join(wait_time) thread.join(wait_time)
@ -290,20 +292,27 @@ def shutdown(sock, read_sockets, write_sockets):
sleep(wait_time) sleep(wait_time)
def main_loop(sock): def main_loop(sock, is_daemon):
sock.listen(128) sock.listen(128)
sock.setblocking(0) sock.setblocking(0)
read_sockets, write_sockets = set(), set() read_sockets, write_sockets = set(), set()
result_map = {} result_map = {}
parser = get_main_argparser(NonInteractiveArgParser)
state = State()
try: try:
try: try:
while True: while True:
do_one(sock, read_sockets, write_sockets, result_map) do_one(
sock, read_sockets, write_sockets, result_map,
is_daemon=is_daemon,
argparser=parser,
state=state,
)
except KeyboardInterrupt: except KeyboardInterrupt:
raise SystemExit(0) raise SystemExit(0)
except SystemExit as e: except SystemExit as e:
shutdown(sock, read_sockets, write_sockets) shutdown(sock, read_sockets, write_sockets, state)
raise e raise e
return 0 return 0
@ -313,10 +322,10 @@ def daemonize(stdin=os.devnull, stdout=os.devnull, stderr=os.devnull):
pid = os.fork() pid = os.fork()
if pid > 0: if pid > 0:
# exit first parent # exit first parent
sys.exit(0) raise SystemExit(0)
except OSError as e: except OSError as e:
sys.stderr.write("fork #1 failed: %d (%s)\n" % (e.errno, e.strerror)) sys.stderr.write("fork #1 failed: %d (%s)\n" % (e.errno, e.strerror))
sys.exit(1) raise SystemExit(1)
# decouple from parent environment # decouple from parent environment
os.chdir("/") os.chdir("/")
@ -328,10 +337,10 @@ def daemonize(stdin=os.devnull, stdout=os.devnull, stderr=os.devnull):
pid = os.fork() pid = os.fork()
if pid > 0: if pid > 0:
# exit from second parent # exit from second parent
sys.exit(0) raise SystemExit(0)
except OSError as e: except OSError as e:
sys.stderr.write("fork #2 failed: %d (%s)\n" % (e.errno, e.strerror)) sys.stderr.write("fork #2 failed: %d (%s)\n" % (e.errno, e.strerror))
sys.exit(1) raise SystemExit(1)
# Redirect standard file descriptors. # Redirect standard file descriptors.
si = open(stdin, 'rb') si = open(stdin, 'rb')
@ -340,12 +349,11 @@ def daemonize(stdin=os.devnull, stdout=os.devnull, stderr=os.devnull):
os.dup2(si.fileno(), sys.stdin.fileno()) os.dup2(si.fileno(), sys.stdin.fileno())
os.dup2(so.fileno(), sys.stdout.fileno()) os.dup2(so.fileno(), sys.stdout.fileno())
os.dup2(se.fileno(), sys.stderr.fileno()) os.dup2(se.fileno(), sys.stderr.fileno())
global is_daemon return True
is_daemon = True
def check_existing(): def check_existing(address):
if use_filesystem: if USE_FILESYSTEM:
# We cannot bind if the socket file already exists so remove it, we # We cannot bind if the socket file already exists so remove it, we
# already have a lock on pidfile, so this should be safe. # already have a lock on pidfile, so this should be safe.
try: try:
@ -363,7 +371,7 @@ def check_existing():
return sock return sock
def kill_daemon(): def kill_daemon(address):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try: try:
try: try:
@ -377,7 +385,7 @@ def kill_daemon():
return True return True
def cleanup_lockfile(fd, *args): def cleanup_lockfile(pidfile, fd, *args):
try: try:
# Remove the directory entry for the lock file # Remove the directory entry for the lock file
os.unlink(pidfile) os.unlink(pidfile)
@ -390,10 +398,7 @@ def cleanup_lockfile(fd, *args):
raise SystemExit(1) raise SystemExit(1)
def lockpidfile(): def lockpidfile(pidfile):
import fcntl
import atexit
import stat
fd = os.open( fd = os.open(
pidfile, pidfile,
os.O_WRONLY | os.O_CREAT, os.O_WRONLY | os.O_CREAT,
@ -408,24 +413,25 @@ def lockpidfile():
os.ftruncate(fd, 0) os.ftruncate(fd, 0)
os.write(fd, ('%d' % os.getpid()).encode('ascii')) os.write(fd, ('%d' % os.getpid()).encode('ascii'))
os.fsync(fd) os.fsync(fd)
cleanup = partial(cleanup_lockfile, fd) cleanup = partial(cleanup_lockfile, pidfile, fd)
signal(SIGTERM, cleanup) signal(SIGTERM, cleanup)
atexit.register(cleanup) atexit.register(cleanup)
return fd return fd
def main(): def main():
global address
global pidfile
parser = get_daemon_argparser() parser = get_daemon_argparser()
args = parser.parse_args() args = parser.parse_args()
is_daemon = False
address = None
pidfile = None
if args.socket: if args.socket:
address = args.socket address = args.socket
if not use_filesystem: if not USE_FILESYSTEM:
address = '\0' + address address = '\0' + address
else: else:
if use_filesystem: if USE_FILESYSTEM:
address = '/tmp/powerline-ipc-%d' address = '/tmp/powerline-ipc-%d'
else: else:
# Use the abstract namespace for sockets rather than the filesystem # Use the abstract namespace for sockets rather than the filesystem
@ -434,13 +440,13 @@ def main():
address = address % os.getuid() address = address % os.getuid()
if use_filesystem: if USE_FILESYSTEM:
pidfile = address + '.pid' pidfile = address + '.pid'
if args.kill: if args.kill:
if args.foreground or args.replace: if args.foreground or args.replace:
parser.error('--kill and --foreground/--replace cannot be used together') parser.error('--kill and --foreground/--replace cannot be used together')
if kill_daemon(): if kill_daemon(address):
if not args.quiet: if not args.quiet:
print ('Kill command sent to daemon, if it does not die in a couple of seconds use kill to kill it') print ('Kill command sent to daemon, if it does not die in a couple of seconds use kill to kill it')
raise SystemExit(0) raise SystemExit(0)
@ -450,19 +456,19 @@ def main():
raise SystemExit(1) raise SystemExit(1)
if args.replace: if args.replace:
while kill_daemon(): while kill_daemon(address):
if not args.quiet: if not args.quiet:
print ('Kill command sent to daemon, waiting for daemon to exit, press Ctrl-C to terminate wait and exit') print ('Kill command sent to daemon, waiting for daemon to exit, press Ctrl-C to terminate wait and exit')
sleep(2) sleep(2)
if use_filesystem and not args.foreground: if USE_FILESYSTEM and not args.foreground:
# We must daemonize before creating the locked pidfile, unfortunately, # We must daemonize before creating the locked pidfile, unfortunately,
# this means further print statements are discarded # this means further print statements are discarded
daemonize() is_daemon = daemonize()
if use_filesystem: if USE_FILESYSTEM:
# Create a locked pid file containing the daemons PID # Create a locked pid file containing the daemons PID
if lockpidfile() is None: if lockpidfile(pidfile) is None:
if not args.quiet: if not args.quiet:
sys.stderr.write( sys.stderr.write(
'The daemon is already running. Use %s -k to kill it.\n' % ( 'The daemon is already running. Use %s -k to kill it.\n' % (
@ -470,7 +476,7 @@ def main():
raise SystemExit(1) raise SystemExit(1)
# Bind to address or bail if we cannot bind # Bind to address or bail if we cannot bind
sock = check_existing() sock = check_existing(address)
if sock is None: if sock is None:
if not args.quiet: if not args.quiet:
sys.stderr.write( sys.stderr.write(
@ -478,14 +484,11 @@ def main():
os.path.basename(sys.argv[0]))) os.path.basename(sys.argv[0])))
raise SystemExit(1) raise SystemExit(1)
if args.foreground: if not USE_FILESYSTEM and not args.foreground:
return main_loop(sock)
if not use_filesystem:
# We daemonize on linux # We daemonize on linux
daemonize() is_daemon = daemonize()
main_loop(sock) return main_loop(sock, is_daemon)
if __name__ == '__main__': if __name__ == '__main__':