#!/usr/bin/env python3

# Copyright 2013-2021 Marcus Furlong <furlongm@gmail.com>
#
# This file is part of Patchman.
#
# Patchman is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 only.
#
# Patchman is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Patchman. If not, see <http://www.gnu.org/licenses/>


import argparse
import os
import sys

from django import setup as django_setup
from django.core.exceptions import MultipleObjectsReturned

# Point Django at the Patchman settings module unless the environment
# already names one; this has to happen before django_setup() is called.
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'patchman.settings')
from django.conf import settings  # noqa

# Initialise Django before the project modules below are imported
django_setup()

from arch.utils import clean_architectures
from errata.tasks import update_errata
from errata.utils import (
    enrich_errata, mark_errata_security_updates,
    scan_package_updates_for_affected_packages,
)
from hosts.models import Host
from hosts.utils import clean_tags, find_host_updates_homogenous
from modules.utils import clean_modules
from packages.utils import (
    clean_packagenames, clean_packages, clean_packageupdates,
)
from reports.models import Report
from reports.tasks import remove_reports_with_no_hosts
from repos.models import Repository
from repos.utils import clean_repos
from security.utils import update_cves, update_cwes
from util import get_setting_of_type
from util.logging import info_message, set_quiet_mode


def get_host(host=None, action='Performing action'):
    """ Helper function to get a single host object

        host is a (possibly partial) hostname. The lookup first tries
        hostnames starting with `host.` so that e.g. `web1` matches
        `web1.example.com` but not `web10.example.com`, then falls back to
        any hostname starting with `host`.

        A message describing the outcome is always emitted via info_message.
        Returns the matching Host, or None when no host is given, no host
        matches, or more than one host matches.
    """
    if not host:
        # Nothing to look up; avoids a TypeError when building the prefix
        info_message(text='No Host specified')
        return None

    host_obj = None
    text = f'{action} for Host {host}'

    # Both lookups live inside one outer try so that an ambiguous match
    # raised by the fallback lookup is reported instead of propagating.
    try:
        try:
            host_obj = Host.objects.get(hostname__startswith=f'{host}.')
        except Host.DoesNotExist:
            host_obj = Host.objects.get(hostname__startswith=host)
    except Host.DoesNotExist:
        text = f'Host {host} does not exist'
    except MultipleObjectsReturned:
        matches = Host.objects.filter(hostname__startswith=host).count()
        text = f'{matches} Hosts match hostname "{host}"'

    info_message(text=text)
    return host_obj


def get_hosts(hosts=None, action='Performing action'):
    """ Helper function to resolve hostnames into Host objects

        hosts may be a single hostname or a list of hostnames; each one is
        looked up with get_host and unresolved names are dropped. Any other
        non-empty value yields an empty list. When hosts is empty, a message
        is emitted and all Hosts are returned.
    """
    if not hosts:
        info_message(text=f'{action} for all Hosts\n')
        return Host.objects.all()

    if isinstance(hosts, str):
        wanted = [hosts]
    elif isinstance(hosts, list):
        wanted = hosts
    else:
        wanted = []

    resolved = (get_host(name, action) for name in wanted)
    return [found for found in resolved if found is not None]


def get_repos(repo=None, action='Performing action', only_enabled=False):
    """ Helper function to get a list of repos

        repo is a Repository id. When given, the result is a list holding
        that single Repository, or an empty list if the id is unknown or
        not a valid id value. When repo is empty, all Repositories are
        returned, restricted to enabled ones if only_enabled is set.

        A message describing the outcome is always emitted via info_message.
    """
    repos = []
    if repo:
        try:
            repos.append(Repository.objects.get(id=repo))
            text = f'{action} for Repo {repo}'
        except (Repository.DoesNotExist, ValueError):
            # ValueError: the id came from the command line and is not
            # numeric; report it the same way as an unknown id
            text = f'Repo {repo} does not exist'
    else:
        text = f'{action} for all Repos\n'
        if only_enabled:
            repos = Repository.objects.filter(enabled=True)
        else:
            repos = Repository.objects.all()

    info_message(text=text)
    return repos


def refresh_repos(repo=None, force=False):
    """ Refresh repository metadata.
        Without a repo ID every enabled repo is refreshed; with a repo ID
        only that repo is refreshed. force is passed on to each refresh.
    """
    selected = get_repos(repo, 'Refreshing metadata', True)
    for repository in selected:
        info_message(text=f'Repository {repository.id} : {repository}')
        repository.refresh(force)
        # blank line separates the output of consecutive repos
        info_message(text='')


def list_repos(repos=None):
    """ Show details of the selected repository
        With no repo given, every repository is shown
    """
    for entry in get_repos(repos, 'Printing information'):
        entry.show()


def list_hosts(hosts=None):
    """ Show details of the selected hosts
        With no host given, every host is shown
    """
    for entry in get_hosts(hosts, 'Printing information'):
        entry.show()


def clean_reports(hoststr=None):
    """ Clean old reports for one host, or for every host when none is given.
        Reports that belong to no host are removed only in the latter case.
    """
    for matched in get_hosts(hoststr, 'Cleaning Reports'):
        matched.clean_reports()

    if hoststr:
        return
    remove_reports_with_no_hosts()


def host_updates_alt(host=None):
    """ Find updates using find_host_updates_homogenous, for one host or
        for every host when none is given
    """
    find_host_updates_homogenous(
        get_hosts(host, 'Finding updates'),
        verbose=True,
    )


def host_updates(host=None):
    """ Find updates host by host, for one host or for every host when
        none is given
    """
    for target in get_hosts(host, 'Finding updates'):
        info_message(text=str(target))
        target.find_updates()
        # blank line separates the output of consecutive hosts
        info_message(text='')


def diff_hosts(hosts):
    """ Print a diff-like comparison of two hosts

        Lines prefixed with `+` belong to the first host, lines prefixed
        with `-` to the second. Exits with status 1 unless exactly two
        hosts could be resolved.
    """
    pair = get_hosts(hosts, 'Retrieving info')
    if len(pair) != 2:
        sys.exit(1)

    first = pair[0]
    second = pair[1]
    first_packages = set(first.packages.all())
    second_packages = set(second.packages.all())
    first_repos = set(first.repos.all())
    second_repos = set(second.repos.all())

    packages_only_first = first_packages - second_packages
    packages_only_second = second_packages - first_packages
    repos_only_first = first_repos - second_repos
    repos_only_second = second_repos - first_repos

    info_message(text=f'+ {first.hostname}')
    info_message(text=f'- {second.hostname}')

    # Single-valued attributes: (attribute, heading, text when identical)
    scalar_sections = (
        ('os', '\nOperating Systems', '\nNo OS differences'),
        ('arch', '\nArchitecture', '\nNo Architecture differences'),
        ('kernel', '\nKernels', '\nNo Kernel differences'),
    )
    for attr, heading, same_text in scalar_sections:
        ours = getattr(first, attr)
        theirs = getattr(second, attr)
        if ours != theirs:
            info_message(text=heading)
            info_message(text=f'+ {ours}')
            info_message(text=f'- {theirs}')
        else:
            info_message(text=same_text)

    # Collections: (only on first, only on second, heading, text when equal)
    set_sections = (
        (packages_only_first, packages_only_second,
         '\nPackages', '\nNo Package differences'),
        (repos_only_first, repos_only_second,
         '\nRepositories', '\nNo Repo differences'),
    )
    for only_first, only_second, heading, same_text in set_sections:
        if only_first or only_second:
            info_message(text=heading)
            for item in only_first:
                info_message(text=f'+ {item}')
            for item in only_second:
                info_message(text=f'- {item}')
        else:
            info_message(text=same_text)


def delete_hosts(hosts=None):
    """ Delete every host resolved from the given hostname(s)
        Does nothing when no host is given
    """
    if not hosts:
        return
    for doomed in get_hosts(hosts):
        info_message(text=f'Deleting host: {doomed.hostname}:')
        doomed.delete()


def toggle_host_hro(hosts=None, host_repos_only=True):
    """ Set or unset host_repos_only on every host resolved from the given
        hostname(s). Does nothing when no host is given
    """
    verb = 'Setting' if host_repos_only else 'Unsetting'
    if not hosts:
        return
    for matched in get_hosts(hosts, f'{verb} host_repos_only'):
        info_message(text=str(matched))
        matched.host_repos_only = host_repos_only
        matched.save()


def toggle_host_check_dns(hosts=None, check_dns=True):
    """ Set or unset check_dns on every host resolved from the given
        hostname(s). Does nothing when no host is given
    """
    verb = 'Setting' if check_dns else 'Unsetting'
    if not hosts:
        return
    for matched in get_hosts(hosts, f'{verb} check_dns'):
        info_message(text=str(matched))
        matched.check_dns = check_dns
        matched.save()


def dns_checks(host=None):
    """ Run the reverse DNS check on one host, or on every host when none
        is given
    """
    for target in get_hosts(host, 'Checking rDNS'):
        target.check_rdns()


def process_reports(host=None, force=False):
    """ Process all pending reports, specify host to process only a single host
        The --force option forces even processed reports to be reprocessed
        No reports are skipped in case some reports contain repo information
        and others only contain package information.

        Reports are selected where their `processed` flag equals `force`
        and are handled oldest first. Each report is processed with
        find_updates=False.
        NOTE(review): with force=True this selects only reports already
        marked processed - confirm that pending ones should be left out.
    """
    filters = {'processed': force}
    if host:
        filters['host'] = host
    reports = Report.objects.filter(**filters).order_by('created')

    if not host:
        text = 'Processing Reports for all Hosts'
    elif reports.exists():
        text = f'Processing Reports for Host {host}'
    else:
        # filter() never raises DoesNotExist, so an empty result has to be
        # detected explicitly to report it
        text = f'No Reports exist for Host {host}'

    info_message(text=text)

    for report in reports.iterator():
        report.process(find_updates=False)


def dbcheck(remove_duplicates=False):
    """ Run every clean_* helper in a fixed order to check database
        consistency. remove_duplicates is passed on to clean_packages
    """
    # (helper, positional args); clean_packageupdates deliberately appears
    # both before and after the other helpers
    steps = (
        (clean_packageupdates, ()),
        (clean_packages, (remove_duplicates,)),
        (clean_packagenames, ()),
        (clean_architectures, ()),
        (clean_repos, ()),
        (clean_modules, ()),
        (clean_packageupdates, ()),
        (clean_tags, ()),
    )
    for step, step_args in steps:
        step(*step_args)


def collect_args():
    """ Build the argparse parser describing every CLI option

        Returns the configured parser; parsing is left to the caller.
    """
    parser = argparse.ArgumentParser(description='Patchman CLI tool')

    def switch(target, *flags, text):
        # Register a boolean on/off option on a parser or option group
        target.add_argument(*flags, action='store_true', help=text)

    switch(parser, '-f', '--force',
           text='Ignore stored checksums and force-refresh all Mirrors')
    switch(parser, '-q', '--quiet',
           text='Quiet mode (e.g. for cronjobs)')
    switch(parser, '-r', '--refresh-repos',
           text='Refresh Repositories')
    parser.add_argument(
        '-R', '--repo',
        help='Only perform action on a specific Repository (repo_id)')
    switch(parser, '-lr', '--list-repos',
           text='List all Repositories')
    switch(parser, '-lh', '--list-hosts',
           text='List all Hosts')
    switch(parser, '-dh', '--delete-hosts',
           text='Delete hosts, requires -H, matches substring patterns')
    switch(parser, '-u', '--host-updates',
           text='Find Host updates')
    switch(parser, '-A', '--host-updates-alt',
           text='Find Host updates (alternative algorithm that may be faster '
                '        when there are many homogeneous hosts)')

    # Setting and unsetting the same host property cannot be combined
    hro_group = parser.add_mutually_exclusive_group()
    switch(hro_group, '-shro', '--set-host-repos-only',
           text='Set host_repos_only, requires -H, matches substring patterns')
    switch(hro_group, '-uhro', '--unset-host-repos-only',
           text='Unset host_repos_only, requires -H, matches substring patterns')
    dns_group = parser.add_mutually_exclusive_group()
    switch(dns_group, '-sdns', '--set-check-dns',
           text='Set check_dns, requires -H, matches substring patterns')
    switch(dns_group, '-udns', '--unset-check-dns',
           text='Unset check_dns, requires -H, matches substring patterns')

    parser.add_argument(
        '-H', '--host',
        help='Only perform action on a specific Host (fqdn)')
    switch(parser, '-p', '--process-reports',
           text='Process pending Reports')
    switch(parser, '-c', '--clean-reports',
           text='Remove all but the last three Reports')
    switch(parser, '-d', '--dbcheck',
           text='Perform some sanity checks and clean unused db entries')
    switch(parser, '-rd', '--remove-duplicates',
           text='Remove duplicates during dbcheck - this may take some time')
    switch(parser, '-n', '--dns-checks',
           text='Perform reverse DNS checks if enabled for that Host')
    switch(parser, '-a', '--all',
           text='Convenience flag for -r -A -p -c -d -n -e')
    parser.add_argument(
        '-D', '--diff', nargs=2, metavar=('hostA', 'hostB'),
        help='Show differences between two Hosts in diff-like output')
    switch(parser, '-e', '--update-errata',
           text='Update Errata')
    parser.add_argument(
        '-E', '--erratum-type',
        help='Only update the specified Erratum type (e.g. `yum`, `ubuntu`, `arch`)')
    switch(parser, '-v', '--update-cves',
           text='Update CVEs from https://cve.org')
    parser.add_argument(
        '--cve', help="Only update the specified CVE (e.g. CVE-2024-1234)")
    switch(parser, '--fetch-nist-data', '-nd',
           text='Fetch NIST CVE data in addition to MITRE data (rate-limited to 1 API call every 6 seconds)')
    return parser


def process_args(args):
    """ Process command line arguments and run the requested actions

        args is the namespace produced by the parser from collect_args.
        The listing, delete, toggle and diff actions are exclusive: the
        first one requested runs and the function returns immediately.
        The remaining actions can be combined and run in a fixed order.

        Returns True when no action was requested, so that the caller can
        print the usage text, and False otherwise.
    """
    showhelp = True
    recheck = False

    if args.all:
        # Expand the convenience flag into -r -A -p -c -d -n -e
        args.process_reports = True
        args.clean_reports = True
        args.refresh_repos = True
        args.host_updates_alt = True
        args.dbcheck = True
        args.dns_checks = True
        # must match the attribute tested by the errata step below
        args.update_errata = True

    # Exclusive actions
    if args.list_repos:
        list_repos(args.repo)
        return False
    if args.list_hosts:
        list_hosts(args.host)
        return False
    if args.delete_hosts:
        delete_hosts(args.host)
        return False
    if args.set_host_repos_only:
        toggle_host_hro(args.host)
        return False
    if args.unset_host_repos_only:
        toggle_host_hro(args.host, False)
        return False
    if args.set_check_dns:
        toggle_host_check_dns(args.host)
        return False
    if args.unset_check_dns:
        toggle_host_check_dns(args.host, False)
        return False
    if args.diff:
        diff_hosts(args.diff)
        return False

    # Combinable actions
    if args.clean_reports:
        clean_reports(args.host)
        showhelp = False
    if args.process_reports:
        process_reports(args.host, args.force)
        showhelp = False
    if args.dbcheck:
        dbcheck(args.remove_duplicates)
        showhelp = False
    if args.refresh_repos:
        refresh_repos(args.repo, args.force)
        showhelp = False
        recheck = True
    if args.host_updates:
        host_updates(args.host)
        showhelp = False
        recheck = True
    if args.host_updates_alt:
        host_updates_alt(args.host)
        showhelp = False
        recheck = True
    if args.dns_checks:
        dns_checks(args.host)
        showhelp = False
    if args.update_errata:
        concurrent = get_setting_of_type(
            setting_name='CONCURRENT_PROCESSING',
            setting_type=bool,
            default=True,
        )
        max_workers = get_setting_of_type(
            setting_name='CONCURRENT_WORKERS',
            setting_type=int,
            default=25,
        )
        update_errata(args.erratum_type, args.force, args.repo)
        scan_package_updates_for_affected_packages()
        mark_errata_security_updates(concurrent, max_workers)
        enrich_errata(concurrent, max_workers)
        showhelp = False
    if args.update_cves:
        update_cves(args.cve, args.fetch_nist_data)
        update_cwes(args.cve)
        showhelp = False
    if args.dbcheck and recheck:
        # dbcheck already ran above; run it again because repos were
        # refreshed or host updates were recalculated afterwards
        dbcheck(args.remove_duplicates)
    return showhelp


def main():
    """ Parse the command line, apply quiet mode, run the requested actions
        and print the usage text when no action was requested
    """
    parser = collect_args()
    args = parser.parse_args()
    set_quiet_mode(args.quiet)
    if process_args(args):
        parser.print_help()


# Run the CLI only when this file is executed as a script
if __name__ == '__main__':
    main()
