#!python
#
#     Copyright (C) 2017–2024, Jose Manuel Martí Martínez
#
#     This program is free software: you can redistribute it and/or modify
#     it under the terms of the GNU Affero General Public License as
#     published by the Free Software Foundation, either version 3 of the
#     License, or (at your option) any later version.
#
#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU Affero General Public License for more details.
#
#     You should have received a copy of the GNU Affero General Public License
#     along with this program. If not, see <https://www.gnu.org/licenses/>.
#
"""
Selectively extract reads following Centrifuge/Kraken output.
"""
# pylint: disable=no-name-in-module, not-an-iterable
import argparse
import gzip
import os
import sys
import time
from collections import Counter
from typing import Callable, List, Set, TextIO

from Bio import SeqIO, SeqRecord

from recentrifuge import __version__, __author__, __date__
from recentrifuge.config import Filename, Id, Score, GZEXT
from recentrifuge.config import LICENSE, NODES_FILE, NAMES_FILE, TAXDUMP_PATH
from recentrifuge.config import gray, red, green, cyan, magenta, blue, yellow
from recentrifuge.rank import Rank, Ranks, TaxLevels
from recentrifuge.taxonomy import Taxonomy
from recentrifuge.trees import TaxTree

# Internal constants
# NOTE(review): not referenced anywhere in the code visible in this file;
#  presumably a cap on how many taxids get spelled out in an output
#  filename -- confirm the intent or remove.
MAX_LENGTH_TAXID_LIST = 32


def main():
    """Extract from FASTQ file(s) the reads selected in a Centrifuge output.

    Workflow: parse the command line; unless only unclassified reads are
    requested, load the NCBI nodes/names and resolve the set of taxids to
    search for; collect the matching records from the Centrifuge output
    file; scan the FASTQ file(s) for those read ids; finally write the
    extracted reads to new FASTQ file(s) named after the input(s).

    Raises:
        AssertionError: if the include/exclude options select no taxid.
        Exception: if the Centrifuge output or a FASTQ file is not found.
        SystemExit: status 1 if no record of the Centrifuge output matches
            the selection; status 2 if none of the matching reads is found
            in the FASTQ file(s); also raised by argparse on bad options.
    """

    def configure_parser():
        """Build and return the argparse parser for the command line."""
        parser = argparse.ArgumentParser(
            description='Selectively extract reads by Centrifuge output',
            epilog=f'%(prog)s  - Release {__version__} - {__date__}' + LICENSE,
            formatter_class=argparse.RawDescriptionHelpFormatter
        )
        parser.add_argument(
            '-d', '--debug',
            action='store_true',
            help='increase output verbosity and perform additional checks'
        )
        parser.add_argument(
            '-f', '--file',
            action='store',
            metavar='FILE',
            required=True,
            help='Centrifuge output file'
        )
        parser.add_argument(
            '-l', '--limit',
            action='store',
            metavar='NUMBER',
            type=int,
            default=None,
            help=('limit of FASTQ reads to extract; '
                  'default: no limit')
        )
        parser.add_argument(
            '-m', '--maxreads',
            action='store',
            metavar='NUMBER',
            type=int,
            default=None,
            help=('maximum number of FASTQ reads to search for the taxa; '
                  'default: no maximum')
        )
        parser.add_argument(
            '-n', '--nodespath',
            action='store',
            metavar='PATH',
            default=TAXDUMP_PATH,
            help=('path for the nodes information files '
                  '(nodes.dmp and names.dmp from NCBI)')
        )
        parser_selection = parser.add_argument_group(
            'selection',
            'Selection of reads based on the classification')
        # -u and -i are mutually exclusive; -x lives in the same help group
        #  but outside the exclusive group, so it can be combined with both
        parser_classif = parser_selection.add_mutually_exclusive_group(
            required=False)
        parser_classif.add_argument(
            '-u', '--unclassified',
            action='store_true',
            help='Just extract unclassified reads (overrides other options)'
        )
        parser_classif.add_argument(
            '-i', '--include',
            action='append',
            metavar='TAXID',
            type=Id,
            default=[],
            help=('NCBI taxid code to include a taxon and all underneath ' +
                  '(multiple -i is available to include several taxid); ' +
                  'by default all the taxa is considered for inclusion')
        )
        parser_selection.add_argument(
            '-x', '--exclude',
            action='append',
            metavar='TAXID',
            type=Id,
            default=[],
            help=('NCBI taxid code to exclude a taxon and all underneath ' +
                  '(multiple -x is available to exclude several taxid)')
        )
        parser.add_argument(
            '-y', '--minscore',
            action='store',
            metavar='NUMBER',
            type=lambda txt: Score(float(txt)),
            default=None,
            help=('minimum score/confidence of the classification of a read '
                  'to pass the quality filter; all pass by default')
        )
        parser_in = parser.add_argument_group(
            'input', 'Define Rextract input files')
        # Exactly one of -q / -1 is required; -2 is checked after parsing
        filein = parser_in.add_mutually_exclusive_group(required=True)
        filein.add_argument(
            '-q', '--fastq',
            action='store',
            metavar='FILE',
            default=None,
            help='single FASTQ file (no paired-ends), which may be gzipped'
        )
        filein.add_argument(
            '-1', '--mate1',
            action='store',
            metavar='FILE',
            default=None,
            help='paired-ends FASTQ file (gzipped or not) for mate 1s '
                 '(filename usually includes _1)'
        )
        parser_in.add_argument(
            '-2', '--mate2',
            action='store',
            metavar='FILE',
            default=None,
            help='paired-ends FASTQ file (gzipped or not) for mate 2s '
                 '(filename usually includes _2)'
        )
        parser.add_argument(
            '-c', '--compress',
            action='store_true',
            help='any generated FASTQ file will be gzipped'
        )
        parser.add_argument(
            '-V', '--version',
            action='version',
            version=f'%(prog)s release {__version__} ({__date__})'
        )
        return parser

    def check_debug():
        """In debugging mode, print every option with a non-default value."""
        if args.debug:
            print(blue('INFO:'), gray('Debugging mode activated'))
            print(blue('INFO:'), gray('Active parameters:'))
            for key, val in vars(args).items():
                if val is not None and val is not False and val != []:
                    print(gray(f'\t{key} ='), f'{val}')

    # timing initialization
    start_time: float = time.time()
    # Program header
    print(f'\n=-= {sys.argv[0]} =-= v{__version__} - {__date__}'
          f' =-= by {__author__} =-=\n')
    sys.stdout.flush()

    # Parse arguments
    argparser = configure_parser()
    args = argparser.parse_args()
    # The parser alone cannot express that -1 and -2 go together: reject
    #  '-1' without '-2' and '-q' with '-2' here instead of failing later
    #  when trying to open a None filename
    if (args.mate1 is None) != (args.mate2 is None):
        argparser.error('paired-ends input requires both '
                        '-1/--mate1 and -2/--mate2')
    output_file = args.file
    nodesfile: Filename = Filename(os.path.join(args.nodespath, NODES_FILE))
    namesfile: Filename = Filename(os.path.join(args.nodespath, NAMES_FILE))
    unclass: bool = args.unclassified
    excluding: Set[Id] = set(args.exclude)
    including: Set[Id] = set(args.include)
    fastq_1: Filename
    fastq_2: Filename = args.mate2
    if not fastq_2:
        fastq_1 = args.fastq
    else:
        fastq_1 = args.mate1

    check_debug()

    # Load NCBI nodes, names and build children
    taxids: Set[Id] = set()  # only populated (and used) if not unclass
    if not unclass:
        plasmidfile: Filename = None
        ncbi: Taxonomy = Taxonomy(nodesfile, namesfile, plasmidfile,
                                  False, excluding, including)

        # Build taxonomy tree
        print(gray('Building taxonomy tree...'), end='')
        sys.stdout.flush()
        tree = TaxTree()
        tree.grow(ontology=ncbi, look_ancestors=False)
        print(green(' OK!'))

        # Get the taxa
        print(gray('Filtering taxa...'), end='')
        sys.stdout.flush()
        ranks: Ranks = Ranks({})
        tree.get_taxa(ranks=ranks,
                      include=including,
                      exclude=excluding)
        print(green(' OK!'))
        taxids = set(ranks)
        taxlevels: TaxLevels = Rank.ranks_to_taxlevels(ranks)
        num_taxlevels = Counter({rank: len(taxlevels[rank])
                                 for rank in taxlevels})
        num_taxlevels = +num_taxlevels  # drop levels with no taxid

        # Statistics about including taxa
        print(f'  {len(taxids)}', gray('taxid selected in '), end='')
        print(f'{len(num_taxlevels)}', gray('different taxonomic levels:'))
        for rank in num_taxlevels:
            print(f'  Number of different {rank}: {num_taxlevels[rank]}')
        # Explicit raise instead of 'assert' so that the validation is not
        #  stripped when running with python -O
        if not taxids:
            raise AssertionError(red('ERROR! No taxids to search for!'))

    # Get the records
    records: List[SeqRecord.SeqRecord] = []
    num_seqs: int = 0
    # timing initialization
    start_time_load: float = time.perf_counter()
    print(gray('Parsing Centrifuge output file'), output_file, gray('...'),
          end='')
    sys.stdout.flush()
    try:
        with open(output_file, 'r', newline=None) as file:
            file.readline()  # discard header
            # Count from 1 so that num_seqs ends as the number of records
            #  parsed (not the index of the last one)
            for num_seqs, record in enumerate(
                    SeqIO.parse(file, 'centrifuge'), start=1):
                if unclass:
                    if record.dbxrefs[0] != 'unclassified':
                        continue  # Ignore read if classified
                    records.append(record)
                else:
                    tid: Id = record.annotations['taxID']
                    if tid not in taxids:
                        continue
                    score: Score = Score(record.annotations['score'])
                    if args.minscore is not None and score < args.minscore:
                        continue  # Ignore read if low confidence
                    records.append(record)
    except FileNotFoundError as err:
        raise Exception(red('ERROR!') + 'Cannot read "' +
                        output_file + '"') from err
    print(green(' OK!'))

    # Basic records statistics (or exit if no matching read)
    print(gray('  Parse elapsed time: ') +
          f'{time.perf_counter() - start_time_load:.3g}' + gray(' sec'))
    if not records:
        # Nothing to look for in the FASTQ file(s): stop here
        print(red('ERROR!'), 'No matching read found on Centrifuge output!')
        sys.exit(1)
    modifier: str = ('Unclassified' if unclass else 'Matching')
    # records is not empty here, so num_seqs >= 1 and the ratio is safe
    print(gray(f'  {modifier} reads: '), f'{len(records):_d}\t', gray('('),
          f'{len(records) / num_seqs:.4%}', gray('of sample )'))
    sys.stdout.flush()

    def is_gzipped(fpath: Filename):
        """Return True if the file exists and starts with the gzip magic.

        A missing file yields False; the error is reported later, when the
        file is actually opened for reading.
        """
        try:
            with open(fpath, 'rb') as ftest:
                return ftest.read(2) == b'\x1f\x8b'
        except FileNotFoundError:
            return False

    # FASTQ sequence dealing
    # Ids still to be found; an id is removed as soon as its read is located
    records_ids: Set[str] = {record.id for record in records}
    seqs1: List[SeqRecord.SeqRecord] = []
    seqs2: List[SeqRecord.SeqRecord] = []
    extracted: int = 0
    failed: int = 0
    scanned: int = 0  # FASTQ reads (or pairs) actually checked
    i: int = 0
    mate1handler: Callable[..., TextIO]
    mate2handler: Callable[..., TextIO]
    if fastq_2:
        print(gray('Loading FASTQ files'), fastq_1, gray('and'), fastq_2,
              gray('...\nMseqs: '), end='')
        sys.stdout.flush()
        mate1handler = gzip.open if is_gzipped(fastq_1) else open
        mate2handler = gzip.open if is_gzipped(fastq_2) else open
        try:
            with mate1handler(fastq_1, 'rt') as file1, \
                 mate2handler(fastq_2, 'rt') as file2:
                pairs = zip(SeqIO.parse(file1, 'quickfastq'),
                            SeqIO.parse(file2, 'quickfastq'))
                for i, (rec1, rec2) in enumerate(pairs):
                    if not records_ids:
                        print(green(' [all records found]'), end='')
                        break
                    elif args.maxreads and i >= args.maxreads:
                        print(yellow(' [stopping by maxreads limit!]'), end='')
                        break
                    elif args.limit and extracted >= args.limit:
                        print(yellow(f' [stopping by {extracted} limit!]'),
                              end='')
                        break
                    elif not i % 1000000:
                        print(f'{i // 1000000:_d}', end='')
                        sys.stdout.flush()
                    elif not i % 100000:
                        print('.', end='')
                        sys.stdout.flush()
                    scanned = i + 1  # this pair is being checked
                    try:
                        records_ids.remove(rec1.id)
                    except KeyError:
                        try: # Try again for labels like SIM3.1-1999998(/1|/2)
                            records_ids.remove(rec1.id.split('/')[0])
                        except KeyError:
                            failed += 1
                            continue
                    seqs1.append(rec1)
                    seqs2.append(rec2)
                    extracted += 1

        except FileNotFoundError as err:
            raise Exception(
                '\n\033[91mERROR!\033[0m Cannot read FASTQ files') from err
    else:
        print(gray('Loading FASTQ file'), f'{fastq_1}', gray('...\nMseqs: '),
              end='')
        sys.stdout.flush()
        mate1handler = gzip.open if is_gzipped(fastq_1) else open
        try:
            with mate1handler(fastq_1, 'rt') as file1:
                for i, rec1 in enumerate(SeqIO.parse(file1, 'quickfastq')):
                    if not records_ids:
                        print(green(' [all records found]'), end='')
                        break
                    elif args.maxreads and i >= args.maxreads:
                        print(yellow(' [stopping by maxreads limit]!'), end='')
                        break
                    elif args.limit and extracted >= args.limit:
                        print(yellow(f' [stopping by {extracted} limit]!'),
                              end='')
                        break
                    elif not i % 1000000:
                        print(f'{i // 1000000:_d}', end='')
                        sys.stdout.flush()
                    elif not i % 100000:
                        print('.', end='')
                        sys.stdout.flush()
                    scanned = i + 1  # this read is being checked
                    try:
                        records_ids.remove(rec1.id)
                    except KeyError:
                        failed += 1
                    else:
                        seqs1.append(rec1)
                        extracted += 1
        except FileNotFoundError as err:
            raise Exception(
                '\n\033[91mERROR!\033[0m Cannot read FASTQ file') from err
    print(cyan(f' {scanned / 1e+6:.3g} Mseqs'), green('OK! '))
    # Guard the ratios: scanned is 0 for an empty FASTQ file
    extracted_ratio: float = extracted / scanned if scanned else 0.0
    failed_ratio: float = failed / scanned if scanned else 0.0
    print(extracted, gray('successfully extracted reads ('),
          f'{extracted_ratio:.2%}', gray(') and'), failed,
          gray('failed extracted reads ('), f'{failed_ratio:.2%}', gray(')'),
          '\n')
    if not extracted:
        print(yellow('ERROR!'), 'No matching read(s) in the FASTQ file(s)!')
        sys.exit(2)
    elif failed:
        # NOTE(review): 'failed' counts every scanned FASTQ read that was
        #  not selected, so this warning fires for almost any input; the
        #  reads really missing from the FASTQ are the ids still left in
        #  records_ids -- confirm the intended meaning.
        print(yellow('WARNING!'), 'Some reads missing in the FASTQ file(s)!')

    sys.stdout.flush()

    def format_filename(fastq: Filename) -> Filename:
        """Auxiliary function to properly format the output filenames.

        The output name is the input name plus a '_rxtr' tag and a suffix
        describing the selection ('_unclass', or '_incl'/'_excl' followed
        by the taxids joined with '_'), keeping the FASTQ extension and
        adding the gzip one if compression was requested.

        Args:
            fastq: Complete filename of the fastq input file (gzipped or not)

        Returns: Filename of the rextracted fastq output file
        """
        # Get filename and extension, accounting for 2 extensions,
        #  typically found in filename.fastq.gz
        fastq_filename, fastq_ext = os.path.splitext(fastq)
        if fastq_ext.casefold() == GZEXT:
            fastq_filename, fastq_ext = os.path.splitext(fastq_filename)
        if args.compress and not fastq_ext.casefold().endswith(GZEXT):
            fastq_ext += GZEXT
        output_list: List[str] = [fastq_filename, '_rxtr']
        if unclass:
            output_list.append('_unclass')
        else:
            # Sort the taxids: set order changes between runs, which would
            #  give a different output filename for the very same command
            if including:
                output_list.append('_incl')
                output_list.append('_'.join(sorted(including)))
            if excluding:
                output_list.append('_excl')
                output_list.append('_'.join(sorted(excluding)))
        output_list.append(fastq_ext)
        return Filename(''.join(output_list))

    def write_seqs(seqs: List[SeqRecord.SeqRecord], fname_in: Filename,
                   gzipped: bool = False) -> None:
        """Save the extracted sequences and report how many were written.

        Args:
            seqs: Records to write.
            fname_in: Input FASTQ filename the output name is derived from.
            gzipped: If True, write the output through gzip.
        """
        fname_out: Filename = format_filename(fname_in)
        if gzipped:
            with gzip.open(fname_out, 'wt') as fgz:
                SeqIO.write(seqs, fgz, 'quickfastq')
            print(gray('Compressed'), magenta(f'{len(seqs)}'),
                  gray('reads in'), fname_out)
        else:
            SeqIO.write(seqs, fname_out, 'quickfastq')
            print(gray('Wrote'), magenta(f'{len(seqs)}'),
                  gray('reads in'), fname_out)

    write_seqs(seqs1, fastq_1, gzipped=args.compress)
    if fastq_2:
        write_seqs(seqs2, fastq_2, gzipped=args.compress)

    # Timing results
    print(gray('Total elapsed time:'), time.strftime(
        "%H:%M:%S", time.gmtime(time.time() - start_time)))

# Run the extraction only when executed as a script, not when imported
if __name__ == '__main__':
    main()
