#!/usr/bin/env python3
### BEGIN INIT INFO
# Provides: nicos-system
# Required-Start: $local_fs $remote_fs $network $named $time
# Should-Start: tango
# Required-Stop: $local_fs $remote_fs $network
# Default-Start: 2 3 4 5
# Default-Stop: 0 1 6
# Description: Network Integrated Control System init script
### END INIT INFO
# *****************************************************************************
# NICOS, the Networked Instrument Control System of the MLZ
# Copyright (c) 2009-2025 by the NICOS contributors (see AUTHORS)
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
# Module authors:
#   Georg Brandl <g.brandl@fz-juelich.de>
#   Christian Felder <c.felder@fz-juelich.de>
#
# *****************************************************************************

import os
import signal
import socket
import subprocess
import sys
import time
from os import path

import psutil

# Shell-style defaults file that may provide environment variables.
_DEFAULT = path.join(os.sep, *('etc', 'default', 'nicos-system'))
# The directory two levels above the resolved location of this script is made
# importable with the highest priority.
nicos_root = path.split(path.split(path.realpath(__file__))[0])[0]
sys.path[:0] = [nicos_root]


def printerr(*args, **kwargs):
    """Print to stderr and flush immediately.

    Accepts the same arguments as ``print``; a ``file`` keyword passed by the
    caller is discarded so that the output always goes to ``sys.stderr``.
    """
    options = {key: val for key, val in kwargs.items() if key != 'file'}
    print(*args, file=sys.stderr, flush=True, **options)


def printout(*args, **kwargs):
    """Print to stdout and flush immediately.

    Accepts the same arguments as ``print``; a ``file`` keyword passed by the
    caller is discarded so that the output always goes to ``sys.stdout``.
    """
    options = {key: val for key, val in kwargs.items() if key != 'file'}
    print(*args, file=sys.stdout, flush=True, **options)


# Import ``KEY=value`` assignments from the defaults file into the process
# environment, unless the caller already selected an instrument through the
# environment.
if path.isfile(_DEFAULT) and 'INSTRUMENT' not in os.environ:
    try:
        with open(_DEFAULT, 'r', encoding='utf-8') as fd:
            for line in fd:
                # Skip comment lines, including indented ones.
                if line.lstrip().startswith('#'):
                    continue
                (key, sep, value) = line.partition('=')
                if not sep:
                    continue
                key = key.strip()
                # Only drop a leading shell ``export`` keyword; a plain
                # substring replacement would also mangle variable names that
                # merely contain the word "export".
                if key.startswith(('export ', 'export\t')):
                    key = key[len('export'):].strip()
                if not key:
                    # Setting an empty variable name is rejected by the OS and
                    # would abort parsing of the remaining lines.
                    continue
                os.environ[key] = value.rstrip()
    except (OSError, UnicodeDecodeError) as e:
        printerr('  ERROR:', e)
        printerr('WARNING: Ignoring defaults in %s.' % _DEFAULT)

# We need to read the nicos.conf file, so let nicos do that.
# This import deliberately comes late: at this point nicos_root has been put
# on sys.path and the defaults file has been merged into os.environ.
# NOTE(review): presumably the nicos package relies on both -- confirm before
# moving this import to the top of the file.
from nicos import config  # isort:skip

# NOTE(review): presumably this evaluates nicos.conf and populates the
# attributes read below (services, logging_path, pid_path) -- confirm against
# the nicos package.
config.apply()


def get_config():
    """Return ``(services, log_path, pid_path)`` for this host.

    If the configuration has an attribute ``services_<short hostname>``, its
    value is used as the service list, otherwise ``config.services`` is used.
    The logging and pid paths from the configuration are joined onto
    ``nicos_root``.
    """
    try:
        hostname = socket.getfqdn().split('.')[0]
    except OSError as e:
        hostname = ''
        printerr('Could not figure out host name (%s).' % e)
        printerr('Continue with nonspecific services.')

    host_attr = 'services_%s' % hostname
    if not hostname or not hasattr(config, host_attr):
        services = config.services
    else:
        services = getattr(config, host_attr)

    return (
        services,
        path.join(nicos_root, config.logging_path),
        path.join(nicos_root, config.pid_path),
    )


# Resolved once at start-up and used by all actions below.
SERVICES, LOG_PATH, PID_PATH = get_config()


def usage(prog):
    """Print a usage summary for *prog* to stderr and return exit status 1."""
    synopsis = 'Usage: %s start|stop|restart|status [service ...]' % prog
    printerr(synopsis)
    known_services = ', '.join(SERVICES)
    printerr('Possible services are', known_services)
    return 1


# compatibility across multiple versions of psutil

def cmd_line(p):
    """Return the command line of process *p*.

    Works whether ``p.cmdline`` is a plain list attribute or a method that
    has to be called.
    """
    if isinstance(p.cmdline, list):
        return p.cmdline
    return p.cmdline()


def process_name(p):
    """Return the name of process *p*.

    Works whether ``p.name`` is a plain string attribute or a method that has
    to be called.
    """
    if isinstance(p.name, str):
        return p.name
    return p.name()


def process_children(p):
    """Return the descendants of process *p*.

    Tries ``get_children(recursive=True)`` first, falls back to
    ``get_children()`` when the keyword is rejected, and to
    ``children(recursive=True)`` when ``get_children`` does not exist.
    """
    try:
        descendants = p.get_children(recursive=True)
    except TypeError:
        # always recursive in psutil < 1.0
        descendants = p.get_children()
    except AttributeError:
        # get_children() is not in psutil >= 3.0
        descendants = p.children(recursive=True)
    return descendants


def read_pidfile(name, wait=0):
    """Return the pid of the running daemon *name*, or None.

    The pid is read from ``<PID_PATH>/<name>.pid``.  If the file is missing or
    does not contain an integer, reading is retried every 0.1 seconds until
    *wait* seconds have passed.

    The pid is only trusted if the second element of that process's command
    line (the script path) ends with the expected program name.  For a name
    of the form ``prog-instance`` the program name is ``prog``, otherwise it
    is the whole name; this is the same rule ``start_daemon`` uses to pick the
    executable.  The previous check, ``not A and B``, never rejected a foreign
    process for instance daemons unless the instance name happened to be one
    of its arguments, so a recycled pid could be reported as the daemon.

    A pidfile that fails the check is removed (best effort) and None is
    returned.
    """
    begin = time.time()
    pidpath = path.join(PID_PATH, '%s.pid' % name)
    while True:
        try:
            with open(pidpath, 'rb') as f:
                pid = int(f.read())
        except (OSError, ValueError):
            if time.time() > begin + wait:
                return None
        else:
            break
        time.sleep(0.1)

    parts = name.split('-')
    procname = parts[0] if len(parts) == 2 else name

    # check that pid really exists and belongs to our daemon
    try:
        params = cmd_line(psutil.Process(pid))
        is_ours = len(params) > 1 and params[1].endswith(procname)
    except Exception:
        # Process vanished, is inaccessible, or psutil failed otherwise:
        # treat the pidfile as stale, as before.
        is_ours = False
    if is_ours:
        return pid

    try:
        os.unlink(pidpath)
    except OSError:
        pass  # best effort; a leftover stale file is re-checked next time
    return None


def check_stray_procs(name):
    """Warn about processes of service *name* that the pidfile does not cover.

    A process counts as stray if it is running, is neither the pid from the
    pidfile nor one of its descendants, has "python" in its process name and
    ``nicos-<name>`` somewhere in its command line.

    Processes that disappear or cannot be inspected while scanning are
    skipped instead of aborting the whole command.  The command line is
    captured during the scan so that printing the report cannot fail later.

    Returns True (after printing a report to stderr) if stray processes were
    found, False otherwise.
    """
    known_pids = set()
    known_process = read_pidfile(name)
    if known_process:
        known_pids.add(known_process)
        try:
            known_pids.update(
                child.pid
                for child in process_children(psutil.Process(known_process)))
        except psutil.Error:
            # The daemon went away right after its pidfile was validated.
            pass

    marker = 'nicos-%s' % name
    strays = []
    for proc in psutil.process_iter():
        try:
            if not proc.is_running() or proc.pid in known_pids:
                continue
            if 'python' not in process_name(proc):
                continue
            cmdline = ' '.join(cmd_line(proc))
        except psutil.Error:
            # Process exited or may not be inspected; it cannot be reported.
            continue
        if marker in cmdline:
            strays.append((proc.pid, cmdline))

    if not strays:
        return False

    printerr('WARNING: some', name, 'processes are still running!')
    printerr('  PID  Command line')
    for pid, cmdline in strays:
        printerr('%5d  %s' % (pid, cmdline))
    printerr()
    printerr('Use one of the following commands to kill them:')
    pids = ' '.join(str(pid) for pid, _ in strays)
    printerr('  kill %s' % pids)
    printerr('  kill -9 %s' % pids)
    return True


def start_daemon(name):
    """Start the daemon *name* and wait for its pidfile to appear.

    For a name of the form ``prog-instance`` the executable ``bin/nicos-prog``
    is started with ``-S <name>``; otherwise ``bin/nicos-<name>`` is started.
    In both cases ``-d`` is passed.

    Returns True if the daemon is running afterwards (including the case that
    it was already running, which an init script has to treat as success),
    False if stray processes block the start or no valid pidfile shows up
    within 10 seconds.
    """
    # check for processes we would not see
    if check_stray_procs(name):
        printerr()
        printerr('NOT starting any additional daemons until this is resolved.')
        printerr()
        return False
    printout('Starting %s...' % name, end=' ')
    previous_pid = read_pidfile(name)
    if previous_pid:
        printout('already running? (pid = %d)' % previous_pid)
        return True
    try:
        procname, instname = name.split('-')
    except ValueError:
        procname, instname = name, None
    executable = path.join(nicos_root, 'bin', 'nicos-%s' % procname)
    # Use an argument list instead of a shell string so that paths containing
    # spaces or shell metacharacters are passed through unchanged.
    command = [sys.executable, executable]
    if instname:
        command += ['-S', name]
    command.append('-d')
    # The exit status is not used; success is judged by the pidfile below.
    subprocess.call(command)
    pid = read_pidfile(name, wait=10)
    if not pid:
        printout('failed, please look in the logfile at %s'
                 % path.join(LOG_PATH, name, name +
                             time.strftime('-%Y-%m-%d.log')))
        return False
    printout('pid =', pid)
    if name.startswith('cache'):
        # NOTE(review): presumably gives the cache time to come up before
        # services started afterwards try to use it -- confirm.
        time.sleep(3)
    return True


def kill_daemon(name):
    """Stop the daemon *name*, escalating from SIGTERM to SIGKILL.

    SIGTERM is sent first and the pidfile is polled until it no longer points
    to the daemon.  After 6 seconds a single SIGKILL is sent; after 10
    seconds the attempt is given up.

    Returns True if the daemon is not running afterwards (including the case
    that it was not running at all), False if it could not be signalled or
    did not stop in time.  In every case a warning is printed if matching
    stray processes remain.
    """
    try:
        printout('Killing %s...' % name, end=' ')
        pid = read_pidfile(name)
        if not pid:
            printout('not running?')
            return True
        printout('pid =', pid, end=' ')
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # Exited on its own in the meantime; the polling loop below
            # notices that and reports success.
            pass
        except OSError as err:
            printerr('FAILED to terminate:', err)
            return False
        begin = time.time()
        kill_sent = False
        while True:
            time.sleep(0.1)
            pid = read_pidfile(name)
            if not pid:
                printout()
                return True
            current = time.time()
            if current > begin + 10:
                printout('FAILED to stop!')
                return False
            if current > begin + 6:
                # SIGKILL cannot be ignored, so sending it once is enough.
                if not kill_sent:
                    kill_sent = True
                    printout('KILL', end=' ')
                    try:
                        os.kill(pid, signal.SIGKILL)
                    except ProcessLookupError:
                        pass  # already terminated
                    except OSError as err:
                        printerr('FAILED to kill:', err)
                        return False
            elif current > begin + 1:
                # poll less often once the quick exit did not happen
                time.sleep(0.2)
            printout('.', end=' ')
    finally:
        # no matching process should have survived. If they do, warn!
        check_stray_procs(name)


def main(args):
    """Run the init script action given in *args* and return the exit status.

    ``args[0]`` is the program name, ``args[1]`` the action (``start``,
    ``stop``, ``restart`` or ``status``).  The remaining arguments are service
    names and/or ``-a``.  Without service names, all configured services are
    handled.

    Services whose name starts with ``cache`` are left alone by ``stop`` and
    ``restart`` unless they are named explicitly or ``-a`` is given.

    Returns 0 on success and 1 if usage was wrong or any service failed (for
    ``status``: if any service is not running).
    """
    try:
        action = args[1]
    except IndexError:
        return usage(args[0])
    if action not in ('start', 'stop', 'restart', 'status'):
        return usage(args[0])

    daemon_args = args[2:]
    all_option = '-a' in daemon_args
    if all_option:
        daemon_args = [arg for arg in daemon_args if arg != '-a']
    explicit_daemons = bool(daemon_args)
    if explicit_daemons:
        if any(daemon not in SERVICES for daemon in daemon_args):
            return usage(args[0])
        daemons = daemon_args
    else:
        daemons = list(SERVICES)

    def is_protected(name):
        # cache services survive stop/restart unless explicitly requested
        return (name.startswith('cache') and not all_option
                and not explicit_daemons)

    exitstatus = 0

    if action == 'start':
        for name in daemons:
            if not start_daemon(name):
                exitstatus = 1

    elif action == 'stop':
        for name in reversed(daemons):
            if is_protected(name):
                printout('Not stopping %s, use -a option to force' % name)
                continue
            if not kill_daemon(name):
                exitstatus = 1

    elif action == 'restart':
        to_restart = []
        for name in daemons:
            if is_protected(name):
                printout('Not restarting %s, use -a option to force' % name)
                continue
            to_restart.append(name)
        # Stop and start in-process instead of re-invoking this script
        # through a shell: the result is reflected in the exit status and
        # the script does not have to be executable under the name args[0].
        for name in reversed(to_restart):
            if not kill_daemon(name):
                exitstatus = 1
        # As before, only start again if everything could be stopped.
        if exitstatus == 0:
            time.sleep(1)
            for name in to_restart:
                if not start_daemon(name):
                    exitstatus = 1

    elif action == 'status':
        for name in daemons:
            try:
                pid = read_pidfile(name)
            except Exception as e:
                printerr('%-12s: could not read pidfile: %s' % (name, e))
                # Without a pid nothing can be reported for this service;
                # do not fall through with an unbound or stale value.
                exitstatus = 1
                continue
            if pid:
                printout('%-12s: running (pid = %s)' % (name, pid))
            else:
                printout('%-12s: dead' % name)
                exitstatus = 1
            # ignoring return value here.
            check_stray_procs(name)

    return exitstatus


# Script entry point: run the requested action and exit with its status.
try:
    exitstatus = main(sys.argv)
except BaseException as e:
    printerr('ERROR:', e.__class__.__name__, '-', e)
    # Report the failure to the caller; previously the script fell off the
    # end here and exited with status 0 despite the error.
    sys.exit(1)
else:
    sys.exit(exitstatus)
