#!python
#
# Sniffles2
# A fast structural variant caller for long-read sequencing data
#
# Created:     28.05.2021
# Author:      Moritz Smolka
# Maintainer:  Hermann Romanek
# Contact:     sniffles@romanek.at
#
import logging
import logging.config
import multiprocessing
from collections import deque
from typing import Optional

from sniffles.utils.resmon import ResourceMonitor

import sys

if not sys.version_info >= (3, 10):
    print(f"Error: Sniffles2 must be run with Python version 3.10 or above (detected Python version: "
          f"{sys.version_info.major}.{sys.version_info.minor}). Exiting")
    exit(1)

import math
import time
import os

import pysam

from sniffles.config import SnifflesConfig
from sniffles import vcf
from sniffles import snf
from sniffles import parallel
from sniffles import util

# Developer switch: when True, the total memory usage of this process and its
# children is written to the debug log at the end of a run (needs psutil).
DEV_MONITOR_MEM = False

if DEV_MONITOR_MEM:
    try:
        import psutil
    except ImportError:
        logging.getLogger('sniffles.memory').warning('psutil not available.')
        psutil = None

        def dbg_get_total_memory_usage_MB():
            """Fallback used when psutil cannot be imported.

            Returns NaN rather than None: the caller formats the value with
            ``:.2f``, which raises TypeError for None but prints "nan" here.
            """
            return float("nan")
    else:
        logging.getLogger('sniffles.memory').info('Watching memory')


        def dbg_get_total_memory_usage_MB():
            """Return the summed RSS of this process and all of its
            (recursive) child processes, divided by 10**6."""
            proc = psutil.Process(os.getpid())
            total = proc.memory_info().rss
            total += sum(child.memory_info().rss for child in proc.children(recursive=True))
            return total / (1000.0 * 1000.0)


def Sniffles2_Main(processes: list[parallel.SnifflesWorker]):
    """Run one complete Sniffles2 invocation.

    The run mode is derived from the extension(s) of ``config.input``:
    ``call_sample`` or ``genotype_vcf`` for a single .bam/.cram file and
    ``combine`` for .snf files or a .tsv file listing them. The function then
    opens inputs and outputs, splits the work into tasks, lets the workers
    run until none of them is running any more, and finally emits the
    collected results to the requested VCF / SNF outputs.

    Args:
        processes: List to which the created worker objects are appended, so
            the caller still holds them if this function is left through an
            exception.

    Returns:
        None.

    Raises:
        Invalid input or configuration is reported through
        ``util.fatal_error_main``. The code below relies on that call not
        returning (presumably it raises ``util.Sniffles2Exit`` - confirm in
        ``sniffles.util``).
    """
    #
    # Determine Sniffles2 run mode
    #
    config = SnifflesConfig()

    if config.dev_debug:
        try:
            import pydevd_pycharm  # noqa
        except ImportError:
            logging.getLogger('sniffles.dev').info(f'pydevd not available!')
        else:
            try:
                pydevd_pycharm.settrace('localhost', port=config.dev_debug, stdoutToServer=True, stderrToServer=True)
            except:  # noqa
                # Deliberately broad: a debugger that cannot be reached must never abort the run.
                logging.getLogger('sniffles.dev').info(f'Failed to connect to PyCharm debugger.', exc_info=True)
            else:
                logging.getLogger('sniffles.dev').info(f'Connected to PyCharm debugger.')

    if config.no_progress:
        logging.getLogger('sniffles.progress').setLevel(logging.CRITICAL)
        logging.getLogger('sniffles.worker').setLevel(logging.CRITICAL)

    # Lower-cased file extension of every input; drives the run mode selection below
    input_ext = [f.split(".")[-1].lower() for f in config.input]

    # needed for running on osx or Python 3.14+ on Linux
    if sys.platform in ("darwin", "linux"):
        multiprocessing.set_start_method("fork")

    if len(set(input_ext)) > 1:
        util.fatal_error_main(f"Please specify either: A single .bam/.cram file - OR - one or more .snf files - OR - "
                              f"a single .tsv file containing a list of .snf files and optional sample ids as input. "
                              f"(supplied were: {list(set(input_ext))})")

    if "bam" in input_ext or "cram" in input_ext:
        # Count both alignment formats; counting only "bam" reported "got 0" for several .cram inputs
        alignment_file_count = input_ext.count("bam") + input_ext.count("cram")
        if alignment_file_count > 1:
            util.fatal_error_main(f"Please specify max 1 .bam//.cram file as input (got {alignment_file_count})")
        config.input = config.input[0]

        if config.genotype_vcf is not None:
            config.mode = "genotype_vcf"
        else:
            config.mode = "call_sample"

        config.input_is_cram = False
        if "bam" in input_ext:
            config.input_mode = r"rb"
        elif "cram" in input_ext:
            config.input_mode = r"rc"
            config.input_is_cram = True

    elif "snf" in input_ext or "tsv" in input_ext:
        config.mode = "combine"
    else:
        util.fatal_error_main(f"Failed to determine run mode from input. Please specify either: A single .bam file "
                              f"- OR - one or more .snf files - OR - a single .tsv file containing a list of .snf "
                              f"files and optional sample ids as input. (supplied were: {list(set(input_ext))})")

    if config.mode != "call_sample" and config.snf is not None:
        util.fatal_error_main(f"--snf cannot be used with run mode {config.mode}")

    if config.vcf is None and config.snf is None:
        util.fatal_error_main("Please specify at least one of: --vcf or --snf for output "
                              "(both may be used at the same time)")

    log = logging.getLogger('sniffles.main')
    if config.dev_debug_log:
        logging.getLogger().setLevel(logging.DEBUG)
    if config.dev_progress_log:
        logging.getLogger('sniffles.progress').setLevel(logging.INFO)

    if config.mode == "call_sample":
        if config.sample_id is None:
            # config.sample_id,_=os.path.splitext(os.path.basename(config.input))
            config.sample_ids_vcf = [(0, "SAMPLE")]
        else:
            config.sample_ids_vcf = [(0, config.sample_id)]

    elif config.mode == "combine":
        config.sample_id = None

        if config.combine_consensus:
            config.sample_ids_vcf = [(0, "CONSENSUS")]
        else:
            config.sample_ids_vcf = []  # Determined from .snf headers

    log.info(f"Running {config.version}, build {config.build}")
    log.info(f"  Run Mode: {config.mode}")
    log.info(f"  Start on: {config.start_date}")
    log.info(f"  Working dir: {config.workdir}")
    log.info(f"  Used command: {config.command}")
    log.info("==============================")

    rkwargs = {}  # result kwargs
    bam_in = None

    monitor = ResourceMonitor(config)
    if monitor and monitor.filename is not None:
        logging.getLogger('sniffles.resources').info(f'Logging memory usage to {monitor.filename}')

    #
    # call_sample/genotype_vcf: Open .bam file for single calling / .snf creation
    #
    contig_tandem_repeats = {}
    if config.mode == "call_sample" or config.mode == "genotype_vcf":
        log.info(f"Opening for reading: {config.input}")
        bam_in = pysam.AlignmentFile(config.input, config.input_mode)
        try:
            has_index = bam_in.check_index()
            if not has_index:
                # Route a falsy result through the same handler as a ValueError from check_index()
                raise ValueError
        except ValueError:
            util.fatal_error_main(f"Unable to load index for input file '{config.input}'. Please verify that your "
                                  f"input file is sorted + indexed and that the index .bai file is valid and in the "
                                  f"right location.")

        #
        # Load tandem repeat annotations
        #
        if config.tandem_repeats is not None:
            contig_tandem_repeats = util.load_tandem_repeats(config.tandem_repeats, config.tandem_repeat_region_pad)
            log.info(f"Opening for reading: {config.tandem_repeats} (tandem repeat annotations for "
                     f"{len(contig_tandem_repeats)} contigs)")

    #
    # genotype_vcf: Read SVs from VCF to be genotyped
    #
    if config.mode == "genotype_vcf":
        path, ext = os.path.splitext(config.genotype_vcf)
        ext = ext.lower()
        vcf_in_handle = None
        if ext == ".gz":
            vcf_in_handle = pysam.VariantFile(config.genotype_vcf, "rb")
        elif ext == ".vcf":
            vcf_in_handle = open(config.genotype_vcf, "r")
        else:
            util.fatal_error_main("Expected a .vcf or .vcf.gz file for genotyping using --genotype-vcf")
        vcf_in = vcf.VCF(config, vcf_in_handle)

        # SVs are indexed by their raw VCF line index (kept in input order) and grouped by contig
        genotype_lineindex_order = []
        genotype_lineindex_svs = {}
        genotype_contig_svs = {}
        for svcall in vcf_in.read_svs_iter():
            if svcall.contig not in genotype_contig_svs:
                genotype_contig_svs[svcall.contig] = []
            assert (svcall.raw_vcf_line_index not in genotype_lineindex_svs)
            genotype_lineindex_order.append(svcall.raw_vcf_line_index)
            genotype_lineindex_svs[svcall.raw_vcf_line_index] = svcall
            genotype_contig_svs[svcall.contig].append(svcall)
        rkwargs['genotype_lineindex_order'] = genotype_lineindex_order
        log.info(f"Opening for reading: {config.genotype_vcf} (read {len(genotype_lineindex_svs)} SVs to be genotyped)")

    #
    # Open output files
    #
    vcf_out = None
    if config.vcf is not None:

        # Human-readable summary of the output flavour, only used in log messages
        vcf_output_info = []
        if config.mode == "combine":
            vcf_output_info.append("multi-sample")
        else:
            vcf_output_info.append("single-sample")
        if config.sort:
            vcf_output_info.append("sorted")
        if config.vcf_output_bgz:
            vcf_output_info.append("bgzipped")
            vcf_output_info.append("tabix-indexed")

        if len(vcf_output_info) == 0:
            vcf_output_info_str = ""
        else:
            vcf_output_info_str = f"({', '.join(vcf_output_info)})"

        if os.path.exists(config.vcf) and not config.allow_overwrite:
            util.fatal_error_main(f"Output file '{config.vcf}' already exists! Use --allow-overwrite to ignore "
                                  f"this check and overwrite.")

        if config.vcf_output_bgz:
            if not config.sort:
                util.fatal_error_main(".gz (bgzip) output is only supported with sorting enabled")

        parent_dir = os.path.dirname(os.path.abspath(config.uncompressed_vcf_name))
        if not os.path.exists(parent_dir):
            util.fatal_error_main(f"Directory {parent_dir} does not exists.")

        # Always written uncompressed first; compression/indexing happens at the very end
        vcf_handle = open(config.uncompressed_vcf_name, "w")
        vcf_out = vcf.VCF(config, vcf_handle)

        if config.mode == "call_sample" or config.mode == "combine":
            if config.reference is not None:
                log.info(f"Opening for reading: {config.reference}")
            vcf_out.open_reference()

        log.info(f"Opening for writing: {config.vcf} {vcf_output_info_str}")

    # SNF writing during single sample mode
    snf_out = None
    if config.snf is not None:
        log.info(f"Opening for writing: {config.snf}")

        if os.path.exists(config.snf) and not config.allow_overwrite:
            util.fatal_error_main(f"Output file '{config.snf}' already exists! Use --allow-overwrite to ignore "
                                  f"this check and overwrite.")
        else:
            snf_out = snf.SNFile(config, open(config.snf, "wb"))

    psnf_out = None
    if psnf_name := config.dev_population_snf:
        log.info(f'Creating population SNF: {psnf_name}')

        if os.path.exists(psnf_name) and not config.allow_overwrite:
            util.fatal_error_main(f'Population SNF {psnf_name} already exists!')
        from sniffles.snfp import PopulationSNF
        psnf_out = PopulationSNF(config, open(psnf_name, 'wb'))
        rkwargs['psnf_out'] = psnf_out

    #
    # Plan multiprocessing tasks
    #
    task_id = 0
    tasks = deque()
    contigs = []
    contig_tasks_intervals = {}

    if config.mode == "call_sample" or config.mode == "genotype_vcf":
        #
        # Process .bam header
        #
        task_classes = {
            'call_sample': parallel.CallTask,
            'genotype_vcf': parallel.GenotypeTask,
        }

        total_mapped = bam_in.mapped
        if (config.threads == 1 and not config.low_memory) or config.task_count_multiplier == 0:
            task_max_reads = total_mapped
        else:
            # config.threads may be 0 (a single parent worker is created in that case, see
            # "Start workers" below); use 1 as divisor then to avoid a ZeroDivisionError.
            worker_count = config.threads or 1
            task_max_reads = max(1, math.floor(total_mapped / (worker_count * config.task_count_multiplier)))

        if total_mapped == 0:
            # Total mapped returns 0 for CRAM files
            config.task_read_id_offset_mult = 10 ** 9
        else:
            # BAM file
            # NOTE(review): math.log is the natural logarithm, so this yields a larger power of ten
            # than log10 would; presumably that only widens the ID offsets - confirm intent.
            config.task_read_id_offset_mult = 10 ** math.ceil(math.log(total_mapped) + 1)

        contig_lengths = []
        contigs_with_tr_annotations = 0
        for contig in bam_in.get_index_statistics():
            # Number of tasks for this contig scales with its share of mapped reads
            if task_max_reads == 0:
                task_count = 1
            else:
                task_count = max(1, math.ceil(contig.mapped / float(task_max_reads)))
            contig_str = str(contig.contig)

            contig_length = bam_in.get_reference_length(contig_str)
            if not util.should_process_contig(contig_str, contig_length, config):
                continue

            contigs.append(contig_str)
            contig_lengths.append((contig_str, contig_length))
            task_length = math.floor(contig_length / float(task_count))
            contigs_with_tr_annotations += int(contig_str in contig_tandem_repeats)
            startpos = 0

            # Cut the contig into consecutive windows of task_length bases, one task per window
            while startpos < contig_length - 1:
                endpos = min(contig_length - 1, startpos + task_length)
                if config.genotype_vcf is not None:
                    if contig_str in genotype_contig_svs:
                        genotype_svs = [target_sv for target_sv in genotype_contig_svs[contig_str] if
                                        startpos <= target_sv.pos < endpos]
                    else:
                        genotype_svs = []
                else:
                    genotype_svs = None

                task = task_classes[config.mode](
                    id=task_id,
                    contig=contig_str,
                    start=startpos,
                    end=endpos,
                    assigned_process_id=None,
                    tandem_repeats=contig_tandem_repeats[contig_str] if contig_str in contig_tandem_repeats else None,
                    genotype_svs=genotype_svs,
                    sv_id=0,
                    config=config,
                    regions=config.regions_by_contig.get(contig_str),
                )
                tasks.append(task)
                if contig_str not in contig_tasks_intervals:
                    contig_tasks_intervals[contig_str] = []
                contig_tasks_intervals[contig_str].append((task.start, task.end, task))
                startpos += task_length
                task_id += 1
        config.contig_lengths = contig_lengths

        if contigs_with_tr_annotations < len(contig_lengths) and config.tandem_repeats is not None:
            log.info(f"Info: {contigs_with_tr_annotations} of {len(contig_lengths)} contigs in the input sample have "
                     f"associated tandem repeat annotations.")

            if contigs_with_tr_annotations == 0:
                util.fatal_error_main("A tandem repeat annotations file was provided, but no matching "
                                      "annotations were found for any contig in the sample input file. "
                                      "Please check if the contig naming scheme in the tandem repeat "
                                      "annotations matches with the one in the input sample file.")

    elif config.mode == "combine":
        #
        # Process .snf headers
        #
        config.snf_input_info = []
        total_mapped = 0

        # List of filenames and optional sample label from tsv
        input_snfs_sample_ids: list[tuple[str, Optional[str]]] = []

        if len(config.input) == 1 and input_ext[0] == "tsv":
            log.info(f"Opening for reading: {config.input[0]} (loading list of .snf files and associated sample ids)")
            with open(config.input[0], "r") as tsv_handle:
                for line_index, line in enumerate(tsv_handle):
                    line_strip = line.strip()
                    # Skip blank lines and "#" comments
                    if len(line_strip) == 0 or line_strip[0] == "#":
                        continue
                    parts = line_strip.split("\t")
                    if len(parts) == 1:
                        snf_filename = parts[0]
                        sample_id = None
                    elif len(parts) == 2:
                        snf_filename = parts[0]
                        sample_id = parts[1]
                    else:
                        util.fatal_error_main(f"Invalid sample list .tsv : {config.input[0]} : "
                                              f"Line {line_index + 1} - expected either one or two columns "
                                              f"(first column: .snf filename, second column: optional sample id to "
                                              f"overrule the one specified in the .snf file)")
                    input_snfs_sample_ids.append((snf_filename, sample_id))
        elif input_ext[0] == "snf":
            input_snfs_sample_ids = [(item, None) for item in config.input]
        else:
            util.fatal_error_main("Failed to determine .snf files to be combined. Please specify either one or "
                                  "more .snf files OR a single .tsv file as input for multi-calling.")

        if not input_snfs_sample_ids:
            # Without at least one .snf header, contig_lengths below would never be assigned
            util.fatal_error_main("No .snf files were found in the supplied input. Please specify either one or "
                                  "more .snf files OR a single .tsv file listing at least one .snf file as input "
                                  "for multi-calling.")

        log.info("The following samples will be processed in multi-calling:")
        for snf_internal_id, (input_filename, sample_id) in enumerate(input_snfs_sample_ids):
            snf_in = snf.SNFile(config, open(input_filename, "rb"), filename=input_filename)
            snf_in.read_header()
            total_mapped += snf_in.header["snf_candidate_count"]
            # Overwritten on every iteration: the contig list of the last .snf file is the one used below
            contig_lengths = snf_in.header["config"]["contig_lengths"]
            if not config.dev_skip_snf_validation:
                if config.snf_block_size != snf_in.header["config"]["snf_block_size"]:
                    util.fatal_error_main(f"SNF block size differs for {input_filename}")
                if config.snf_format_version != snf_in.header["config"]["snf_format_version"]:
                    util.fatal_error_main(f"SNF format version for {input_filename} is not supported")
            # Sample id precedence: .tsv column, then .snf header, then the file's base name
            if sample_id is None:
                if snf_in.header["config"]["sample_id"] is not None:
                    sample_id = snf_in.header["config"]["sample_id"]
                else:
                    sample_id, _ = os.path.splitext(os.path.basename(input_filename))
            config.snf_input_info.append({"internal_id": snf_internal_id, "sample_id": sample_id,
                                          "filename": input_filename})
            reqc_info = f' (Rerunning QC)' if snf_in.reqc else ''
            snf_in.close()
            log.info(f"    {input_filename} (sample ID in output VCF='{sample_id}'{reqc_info})")

        if not config.combine_consensus:
            for info in config.snf_input_info:
                config.sample_ids_vcf.append((info["internal_id"], info["sample_id"]))

        if snfp_filename := config.combine_population:
            from sniffles.snfp import PopulationSNF
            # NOTE(review): snfp is not referenced again in this function - confirm whether
            # the open() call is needed for its side effects.
            snfp = PopulationSNF.open(snfp_filename)

        # TODO: Assure header consistency across multiple .snfs
        if to_process := (config.contig or config.regions_by_contig):
            contig_lengths = [(name, length) for name, length in contig_lengths if name in to_process]

        result_class = None
        if len(input_snfs_sample_ids) > config.combine_max_inmemory_results:
            log.info(f'Using tmp file aggregation for merge.')
            from sniffles.result import CombineResultTmpFile
            result_class = CombineResultTmpFile

            if config.sort:
                log.warning(
                    f'Sniffles currently does not support sorting on a number of input files exceeding '
                    f'--combine-max-inmemory-results (currently set to {config.combine_max_inmemory_results}). '
                    f'Unsorted calls will be written to result-{config.run_id}-*-unsorted.part.vcf. '
                    f'Result will be unsorted and uncompressed. To avoid this message please use either --no-sort '
                    f'or increase the number of in-memory results.'
                )
                if config.vcf_output_bgz:
                    config.vcf = config.uncompressed_vcf_name
                    config.no_sort = True
                    vcf_output_info_str += " WARN: unsorted and not indexed"
                    log.warning(f'Result will be unsorted and uncompressed')
                    # util.fatal_error_main('BGZ files require sorting, and thus are currently not
                    #                       supported for the given number of input files.')

        # Takes precedence over the tmp-file result class chosen above
        if config.dev_population_snf:
            from sniffles.result import CombineResultTmpFilePopulationSNF
            result_class = CombineResultTmpFilePopulationSNF

        for contig_str, contig_length in contig_lengths:
            task = parallel.CombineTask(
                id=task_id,
                contig=contig_str,
                start=0,
                end=contig_length - 1,
                assigned_process_id=None,
                sv_id=0,
                config=config,
                result_class=result_class,
                regions=config.regions_by_contig.get(contig_str)
            )
            tasks.extend(task.scatter())
            if contig_str not in contig_tasks_intervals:
                contig_tasks_intervals[contig_str] = []
            contig_tasks_intervals[contig_str].append((task.start, task.end, task))
            # Continue numbering after the last task produced by scatter()
            task_id = tasks[-1].id + 1

        log.info(f"Verified headers for {len(input_snfs_sample_ids)} .snf files.")

    if config.mode != "genotype_vcf" and config.vcf is not None:
        vcf_out.write_header(contig_lengths)
    elif config.mode == "genotype_vcf":
        vcf_out.rewrite_header_genotype(vcf_in.header_str)

    #
    # Start workers
    #
    if config.threads:
        for pnum in range(config.threads):
            processes.append(parallel.SnifflesWorker(process_id=pnum, config=config, tasks=tasks, recycle_hint=monitor))
    else:
        processes.append(parallel.SnifflesParentWorker(config=config, tasks=tasks))

    log.info("")
    if config.mode == "call_sample" or config.mode == "genotype_vcf":
        if config.input_is_cram:
            # CRAM file
            log.info(f"Analyzing alignments... (progress display disabled for CRAM input)")
        else:
            log.info(f"Analyzing {total_mapped} alignments total...")
    elif config.mode == "combine":
        log.info(f"Calling SVs across {len(config.snf_input_info)} samples ({total_mapped} candidates total)...")
    log.info("")

    #
    # Distribute analysis tasks to workers and collect results
    #
    analysis_start_time = time.monotonic()

    for p in processes:
        p.start()

    finished_tasks: list[parallel.Task] = []

    # The list (not a generator) makes sure run_parent() is called on every running worker in each round
    while any([p.run_parent() for p in processes if p.running]):
        time.sleep(0.1)

    for p in processes:
        p.finalize()
        finished_tasks.extend(p.finished_tasks)

    log.info(f"Took {time.monotonic() - analysis_start_time:.2f}s.")
    log.info("")

    # It is possible that workers have all exited without finishing all work
    # Check the tasks deque is actually empty and explode accordingly if not
    if len(tasks) > 0:
        # util.fatal_error_main(f"All workers have exited but work remains to be done, it is possible that all workers "
        #                        f"have been killed by your system due to memory constraints: <{tasks}>")
        log.warning(f"All workers have exited but work remains to be done, it is possible that all workers have been "
                    f"killed by your system due to memory constraints. Partial results will be written.")
        log.debug(f"Missing tasks: {[(t.id, t.sv_id, t.contig, t.regions, t.result) for t in tasks]}")

    # Emit results in task id order, independent of the order in which workers finished
    finished_tasks.sort(key=lambda task: task.id)

    for t in finished_tasks:
        t.result.emit(vcf_out=vcf_out, snf_out=snf_out, **rkwargs)

    if config.dev_output_candidates and config.mode == "call_sample":
        from shutil import copyfileobj
        with open(config.dev_output_candidates, "w") as csv:
            csv.write('svtype,orientation_start,contig_start,pos_start,orientation_end,contig_end,pos_end,'
                      'filter,support_inline,support_split,support_ref\n')
            # Concatenate the per-task candidate files and remove them afterwards
            for t in finished_tasks:
                tmpfile = t.result.candidate_filename  # noqa
                with open(tmpfile, "r") as f:
                    copyfileobj(f, csv)
                os.unlink(tmpfile)

    if snf_out:
        snf_candidate_count = snf_out.write_results(config, contigs)
        snf_out.close()
        log.info(f"Wrote {snf_candidate_count} SV candidates to {config.snf} (for multi-sample calling).")

    if psnf_out:
        c = psnf_out.write_results(config, contigs)
        psnf_out.close()
        log.info(f'Wrote {c} SVs to population SNF.')

    if DEV_MONITOR_MEM:
        logging.debug(f"[DEV: Total memory usage={dbg_get_total_memory_usage_MB():.2f}MB]")

    if config.vcf is not None:
        vcf_out.close()
        if config.vcf_output_bgz:
            vcf_index_start_time = time.time()
            log.info(f"Generating index for {config.vcf}...")
            try:
                filename = pysam.tabix_index(config.uncompressed_vcf_name, preset="vcf", force=True)
            except Exception:
                # Indexing is best effort: the uncompressed VCF has already been written at this point
                log.exception(f'Error indexing VCF.')
            else:
                os.rename(filename, config.vcf)
                log.info(f"Indexing VCF output took {time.time() - vcf_index_start_time:.2f}s.")

    if (config.mode == "call_sample" or config.mode == "combine") and config.vcf is not None:
        log.info(f"Wrote {vcf_out.call_count} called SVs to {config.vcf} {vcf_output_info_str}")

    if monitor:
        log.debug(f'Stopping resource monitoring.')
        monitor.stop()


if __name__ == "__main__":
    # Filled by Sniffles2_Main with the worker objects it creates, so they can
    # still be shut down here when the run is aborted part-way through.
    processes = []

    try:
        logging.config.dictConfig({
            'version': 1,
            'formatters': {
                'default': {
                    'format': '%(asctime)s %(levelname)s %(name)s (%(process)d): %(message)s'
                }
            },
            'handlers': {
                'console': {
                    'class': 'logging.StreamHandler',
                    'formatter': 'default',
                    'stream': 'ext://sys.stdout',
                }
            },
            'loggers': {
                'sniffles.progress': {
                    'level': logging.WARNING,
                },
                'sniffles.vcf': {
                    'level': logging.INFO,
                }
            },
            'root': {
                'level': logging.INFO,
                'handlers': ['console'],
            },
            'disable_existing_loggers': False,
        })
    except (ValueError, TypeError, AttributeError, ImportError):
        # A broken logging setup is reported but does not prevent the run
        logging.exception(f'Error configuring loggers.')

    try:
        Sniffles2_Main(processes)
    except (util.Sniffles2Exit, SystemExit) as exit_code:
        if len(processes):
            # Allow time for child process error messages to propagate
            print("Sniffles2Main: Shutting down workers")
            time.sleep(10)

        # Shutdown is best effort: a worker without a usable process handle, or
        # one that already exited, must not keep the remaining ones from being
        # cleaned up. Failures are kept in the debug log instead of being dropped.
        shutdown_log = logging.getLogger('sniffles.main')
        for proc in processes:
            try:
                proc.process.terminate()
            except Exception:
                shutdown_log.debug('Could not terminate worker.', exc_info=True)

        for proc in processes:
            try:
                proc.process.join()
            except Exception:
                shutdown_log.debug('Could not join worker.', exc_info=True)
        # sys.exit() instead of the site-provided exit() helper
        sys.exit(exit_code.code)
    except BaseException:
        # Last-resort handler: log whatever escaped (same scope as a bare except) and signal failure
        logging.getLogger('sniffles.main').exception(f'Unhandled error while running sniffles.')
        sys.exit(1)
