#!python
#
#     Copyright (C) 2017–2026, Jose Manuel Martí Martínez
#
#     This program is free software: you can redistribute it and/or modify
#     it under the terms of the GNU Affero General Public License as
#     published by the Free Software Foundation, either version 3 of the
#     License, or (at your option) any later version.
#
#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU Affero General Public License for more details.
#
#     You should have received a copy of the GNU Affero General Public License
#     along with this program. If not, see <https://www.gnu.org/licenses/>.
#
"""
Selectively extract reads following Centrifuge/Kraken output.
Support for both FASTQ and FASTA sequence files.
"""
# pylint: disable=no-name-in-module, not-an-iterable
import argparse
import gzip
import os
import sys
import time
from collections import Counter
from typing import Callable, List, Optional, Set, TextIO

from Bio import SeqIO, SeqRecord

from recentrifuge import __version__, __author__, __date__
from recentrifuge.config import Filename, Id, Score, GZEXT
from recentrifuge.config import LICENSE, NODES_FILE, NAMES_FILE, TAXDUMP_PATH
from recentrifuge.config import gray, red, green, cyan, magenta, blue, yellow
from recentrifuge.rank import Rank, Ranks, TaxLevels
from recentrifuge.taxonomy import Taxonomy
from recentrifuge.trees import TaxTree

# Internal constants
# NOTE(review): this constant is not referenced anywhere in the visible part
#  of this module; presumably it is meant as an upper bound for the number
#  of taxids listed somewhere (e.g. in output filenames) -- TODO confirm.
MAX_LENGTH_TAXID_LIST = 32


def main():
    """Main entry point to script.

    Parse the command line, select reads from the Centrifuge output file
    (either the unclassified ones, or those assigned to the selected taxa
    that also pass the optional score filters), look for those reads in the
    FASTQ/FASTA input file(s) and write the ones found to new sequence
    file(s) named after the input(s).

    Exit status:
        1 if no taxid is selected or no read of the Centrifuge output
        matches the selection; 2 if none of the selected reads is found in
        the sequence file(s) (argparse also uses 2 for command-line errors).

    Raises:
        Exception: if the Centrifuge output file or any of the sequence
            files does not exist.
    """

    def configure_parser():
        """Argument Parser Configuration"""
        parser = argparse.ArgumentParser(
            description='Selectively extract reads by Centrifuge output',
            epilog=f'%(prog)s  - Release {__version__} - {__date__}' + LICENSE,
            formatter_class=argparse.RawDescriptionHelpFormatter
        )
        parser.add_argument(
            '-d', '--debug',
            action='store_true',
            help='increase output verbosity and perform additional checks'
        )
        parser.add_argument(
            '-f', '--file',
            action='store',
            metavar='FILE',
            required=True,
            help='Centrifuge output file'
        )
        parser.add_argument(
            '-l', '--limit',
            action='store',
            metavar='NUMBER',
            type=int,
            default=None,
            help=('limit of sequence reads to extract; '
                  'default: no limit')
        )
        parser.add_argument(
            '-m', '--maxreads',
            action='store',
            metavar='NUMBER',
            type=int,
            default=None,
            help=('maximum number of sequence reads to search for the taxa; '
                  'default: no maximum')
        )
        parser.add_argument(
            '-n', '--nodespath',
            action='store',
            metavar='PATH',
            default=TAXDUMP_PATH,
            help=('path for the nodes information files '
                  '(nodes.dmp and names.dmp from NCBI)')
        )
        parser_selection = parser.add_argument_group(
            'selection',
            'Selection of reads based on the classification')
        # -u and -i cannot be combined; -x lives outside the exclusive group
        parser_classif = parser_selection.add_mutually_exclusive_group(
            required=False)
        parser_classif.add_argument(
            '-u', '--unclassified',
            action='store_true',
            help='Just extract unclassified reads (overrides other options)'
        )
        parser_classif.add_argument(
            '-i', '--include',
            action='append',
            metavar='TAXID',
            type=Id,
            default=[],
            help=('NCBI taxid code to include a taxon and all underneath ' +
                  '(multiple -i is available to include several taxid); ' +
                  'by default all the taxa is considered for inclusion')
        )
        parser_selection.add_argument(
            '-x', '--exclude',
            action='append',
            metavar='TAXID',
            type=Id,
            default=[],
            help=('NCBI taxid code to exclude a taxon and all underneath ' +
                  '(multiple -x is available to exclude several taxid)')
        )
        parser.add_argument(
            '-y', '--minscore',
            action='store',
            metavar='NUMBER',
            type=lambda txt: Score(float(txt)),
            default=None,
            help=('minimum score/confidence of the classification of a read '
                  'to pass the quality filter; all pass by default')
        )
        parser.add_argument(
            '-z', '--maxscore',
            action='store',
            metavar='NUMBER',
            type=lambda txt: Score(float(txt)),
            default=None,
            help=('maximum score/confidence of the classification of a read '
                  'to pass the quality filter; all pass by default')
        )
        parser_in = parser.add_argument_group(
            'input', 'Define Rextract input files')
        # Exactly one of -q (single file) or -1 (first mate) is required
        filein = parser_in.add_mutually_exclusive_group(required=True)
        filein.add_argument(
            '-q', '--fastq',
            action='store',
            metavar='FILE',
            default=None,
            help='single sequence file (no paired-ends), which may be gzipped'
        )
        filein.add_argument(
            '-1', '--mate1',
            action='store',
            metavar='FILE',
            default=None,
            help='paired-ends sequence file (gzipped or not) for mate 1s '
                 '(filename usually includes _1)'
        )
        parser_in.add_argument(
            '-2', '--mate2',
            action='store',
            metavar='FILE',
            default=None,
            help='paired-ends sequence file (gzipped or not) for mate 2s '
                 '(filename usually includes _2)'
        )
        parser.add_argument(
            '-c', '--compress',
            action='store_true',
            help='any generated sequence file will be gzipped'
        )
        parser.add_argument(
            '-a', '--fasta',
            action='store_true',
            help='treat all the input and output sequence files as FASTA '
                 'instead of the default FASTQ'
        )
        parser.add_argument(
            '-V', '--version',
            action='version',
            version=f'%(prog)s release {__version__} ({__date__})'
        )
        return parser

    def check_debug():
        """Check debugging mode"""
        if args.debug:
            print(blue('INFO:'), gray('Debugging mode activated'))
            print(blue('INFO:'), gray('Active parameters:'))
            for key, val in vars(args).items():
                # Only show the parameters that are actually set
                if val is not None and val is not False and val != []:
                    print(gray(f'\t{key} ='), f'{val}')

    # timing initialization
    start_time: float = time.time()
    # Program header
    print(f'\n=-= {sys.argv[0]} =-= v{__version__} - {__date__}'
          f' =-= by {__author__} =-=\n')
    sys.stdout.flush()

    # Parse arguments
    argparser = configure_parser()
    args = argparser.parse_args()
    # The parser only enforces "-q or -1"; check here that the mates come
    #  together, so that an incomplete pair fails now with a clear message
    #  instead of crashing on a None filename after the parsing is done.
    if bool(args.mate1) != bool(args.mate2):
        argparser.error('paired-ends input requires both -1/--mate1 '
                        'and -2/--mate2')
    output_file = args.file
    nodesfile: Filename = Filename(os.path.join(args.nodespath, NODES_FILE))
    namesfile: Filename = Filename(os.path.join(args.nodespath, NAMES_FILE))
    unclass: bool = args.unclassified
    excluding: Set[Id] = set(args.exclude)
    including: Set[Id] = set(args.include)
    fast_2: Optional[Filename] = args.mate2
    fast_1: Filename = args.mate1 if fast_2 else args.fastq

    check_debug()

    # Load NCBI nodes, names and build children
    taxids: Set[Id] = set()  # Initialize for type safety
    if not unclass:
        plasmidfile: Optional[Filename] = None
        ncbi: Taxonomy = Taxonomy(nodesfile, namesfile, plasmidfile,
                                  False, excluding, including)

        # Build taxonomy tree
        print(gray('Building taxonomy tree...'), end='')
        sys.stdout.flush()
        tree = TaxTree()
        tree.grow(ontology=ncbi, look_ancestors=False)
        print(green(' OK!'))

        # Get the taxa
        print(gray('Filtering taxa...'), end='')
        sys.stdout.flush()
        ranks: Ranks = Ranks({})
        tree.get_taxa(ranks=ranks,
                      include=including,
                      exclude=excluding)
        print(green(' OK!'))
        taxids = set(ranks)
        taxlevels: TaxLevels = Rank.ranks_to_taxlevels(ranks)
        num_taxlevels = Counter({rank: len(taxlevels[rank])
                                 for rank in taxlevels})
        num_taxlevels = +num_taxlevels  # Drop the levels without any taxid

        # Statistics about including taxa
        print(f'  {len(taxids)}', gray('taxid selected in '), end='')
        print(f'{len(num_taxlevels)}', gray('different taxonomic levels:'))
        for rank in num_taxlevels:
            print(f'  Number of different {rank}: {num_taxlevels[rank]}')
        if not taxids:
            print(red('ERROR!'), 'No taxids to search for!')
            sys.exit(1)

    # Get the ids of the selected reads. Only the ids and the number of
    #  selected reads are needed later, so the records are not stored.
    records_ids: Set[str] = set()
    matching: int = 0  # Number of reads passing the selection
    num_seqs: int = 0  # Total number of reads in the Centrifuge output
    # timing initialization
    start_time_load: float = time.perf_counter()
    print(gray('Parsing Centrifuge output file'), output_file, gray('...'),
          end='')
    sys.stdout.flush()
    try:
        with open(output_file, 'r', newline=None) as file:
            file.readline()  # discard header
            # Count from 1 so that num_seqs ends as the number of reads
            for num_seqs, record in enumerate(
                    SeqIO.parse(file, 'centrifuge'), start=1):
                if unclass:
                    if record.dbxrefs[0] != 'unclassified':
                        continue  # Ignore read if classified
                else:
                    tid: Id = record.annotations['taxID']
                    if tid not in taxids:
                        continue
                    score: Score = Score(record.annotations['score'])
                    if args.minscore is not None and score < args.minscore:
                        continue  # Discard read if score below threshold
                    if args.maxscore is not None and score > args.maxscore:
                        continue  # Discard read if score above threshold
                matching += 1
                if record.id is not None:
                    records_ids.add(record.id)
    except FileNotFoundError as err:
        raise Exception(red('ERROR!') + ' Cannot read "' +
                        output_file + '"') from err
    print(green(' OK!'))

    # Basic records statistics (or exit if no matching read)
    print(gray('  Parse elapsed time: ') +
          f'{time.perf_counter() - start_time_load:.3g}' + gray(' sec'))
    if not matching:
        print(red('ERROR!'), 'No matching read found on Centrifuge output!')
        sys.exit(1)
    modifier: str = ('Unclassified' if unclass else 'Matching')
    print(gray(f'  {modifier} reads: '), f'{matching:_d}\t', gray('('),
          f'{matching / num_seqs:.4%}', gray('of sample )'))
    sys.stdout.flush()

    def is_gzipped(fpath: Filename):
        """Check if a file exists and is gzipped"""
        try:
            with open(fpath, 'rb') as ftest:
                # A gzipped file starts with these two magic bytes
                return ftest.read(2) == b'\x1f\x8b'
        except FileNotFoundError:
            return False

    # FASTQ/FASTA sequence dealing
    seqs1: List[SeqRecord.SeqRecord] = []
    seqs2: List[SeqRecord.SeqRecord] = []
    extracted: int = 0
    scanned: int = 0  # Number of reads (or pairs) actually examined
    seq_format = 'quickfasta' if args.fasta else 'quickfastq'
    seq_label = 'FASTA' if args.fasta else 'FASTQ'
    if fast_2:
        print(gray(f'Loading {seq_label} files'), fast_1, gray('and'), fast_2,
              gray('...\nMseqs: '), end='')
        sys.stdout.flush()
        mate1handler: Callable[..., TextIO] = (
            gzip.open if is_gzipped(fast_1) else open)
        mate2handler: Callable[..., TextIO] = (
            gzip.open if is_gzipped(fast_2) else open)
        try:
            with mate1handler(fast_1, 'rt') as file1, \
                    mate2handler(fast_2, 'rt') as file2:
                pairs = zip(SeqIO.parse(file1, seq_format),
                            SeqIO.parse(file2, seq_format))
                for idx, (rec1, rec2) in enumerate(pairs):
                    # idx is the number of pairs examined before this one
                    if not records_ids:
                        print(green(' [all records found]'), end='')
                        break
                    elif args.maxreads and idx >= args.maxreads:
                        print(yellow(' [stopping by maxreads limit!]'), end='')
                        break
                    elif args.limit and extracted >= args.limit:
                        print(yellow(f' [stopping by {extracted} limit!]'),
                              end='')
                        break
                    elif not idx % 1000000:
                        print(f'{idx // 1000000:_d}', end='')
                        sys.stdout.flush()
                    elif not idx % 100000:
                        print('.', end='')
                        sys.stdout.flush()
                    scanned = idx + 1
                    if rec1.id in records_ids:
                        records_ids.remove(rec1.id)
                    elif rec1.id.split('/')[0] in records_ids:
                        # Handle labels like SIM3.1-1999998(/1|/2)
                        records_ids.remove(rec1.id.split('/')[0])
                    else:
                        continue  # Read not in our wanted set
                    seqs1.append(rec1)
                    seqs2.append(rec2)
                    extracted += 1
        except FileNotFoundError as err:
            raise Exception(
                f'\n\033[91mERROR!\033[0m Cannot read {seq_label} files'
            ) from err
    else:
        print(gray(f'Loading {seq_label} file'), f'{fast_1}',
              gray('...\nMseqs: '), end='')
        sys.stdout.flush()
        fq1handler: Callable[..., TextIO] = (
            gzip.open if is_gzipped(fast_1) else open)
        try:
            with fq1handler(fast_1, 'rt') as file1:
                for idx, rec1 in enumerate(SeqIO.parse(file1, seq_format)):
                    # idx is the number of reads examined before this one
                    if not records_ids:
                        print(green(' [all records found]'), end='')
                        break
                    elif args.maxreads and idx >= args.maxreads:
                        print(yellow(' [stopping by maxreads limit]!'), end='')
                        break
                    elif args.limit and extracted >= args.limit:
                        print(yellow(f' [stopping by {extracted} limit]!'),
                              end='')
                        break
                    elif not idx % 1000000:
                        print(f'{idx // 1000000:_d}', end='')
                        sys.stdout.flush()
                    elif not idx % 100000:
                        print('.', end='')
                        sys.stdout.flush()
                    scanned = idx + 1
                    if rec1.id in records_ids:
                        records_ids.remove(rec1.id)
                        seqs1.append(rec1)
                        extracted += 1
        except FileNotFoundError as err:
            raise Exception(
                f'\n\033[91mERROR!\033[0m Cannot read {seq_label} file'
            ) from err
    print(cyan(f' {scanned / 1e+6:.3g} Mseqs'), green('OK! '))
    missing: int = len(records_ids)  # Selected ids that were never found
    if scanned:
        print(extracted, gray('successfully extracted reads ('),
              f'{extracted / scanned:.2%}', gray(')'))
    else:
        print(extracted, gray('successfully extracted reads'))
    if not extracted:
        print(red('ERROR!'), f'No matching read(s) in the {seq_label} file(s)!')
        sys.exit(2)
    if missing:
        print(yellow('WARNING!'), f'{missing} reads from Centrifuge output',
              f'not found in the {seq_label} file(s)!')
    sys.stdout.flush()

    def format_filename(fast: Filename) -> Filename:
        """Auxiliary function to properly format the output filenames.

        The name is built from the input name without extension(s), the
        '_rxtr' tag, a tag describing the selection ('_unclass', or '_incl'
        and/or '_excl' followed by the taxids joined by '_'), and the
        sequence extension (plus the gzip one if compressing).

        Args:
            fast: Complete filename of the sequence input file (gzipped or not)

        Returns: Filename of the rextracted sequence output file
        """
        # Get filename and extension, accounting for 2 extensions,
        #  typically found in filename.fasta.gz or .fastq.gz
        fast_filename, fast_ext = os.path.splitext(fast)
        if fast_ext.casefold() == GZEXT:
            fast_filename, fast_ext = os.path.splitext(fast_filename)
        if args.compress and not fast_ext.casefold().endswith(GZEXT):
            fast_ext += GZEXT
        output_list: List[str] = [fast_filename, '_rxtr']
        if unclass:
            output_list.append('_unclass')
        else:
            # Sort the taxids: the iteration order of a set of strings can
            #  change between runs, and the filename should not.
            if including:
                output_list.append('_incl')
                output_list.append('_'.join(sorted(including)))
            if excluding:
                output_list.append('_excl')
                output_list.append('_'.join(sorted(excluding)))
        output_list.append(fast_ext)
        return Filename(''.join(output_list))

    def write_seqs(seqs: List[SeqRecord.SeqRecord], fname_in: Filename,
                   gzipped: bool = False) -> None:
        """Aux function to save the extracted sequences

        Args:
            seqs: Sequence records to write
            fname_in: Input sequence filename, used to name the output
            gzipped: If True, write the output through gzip
        """
        fname_out: Filename = format_filename(fname_in)
        if gzipped:
            with gzip.open(fname_out, 'wt') as fgz:
                SeqIO.write(seqs, fgz, seq_format)
            print(gray('Compressed'), magenta(f'{len(seqs)}'),
                  gray('reads in'), fname_out)
        else:
            SeqIO.write(seqs, fname_out, seq_format)
            print(gray('Wrote'), magenta(f'{len(seqs)}'),
                  gray('reads in'), fname_out)

    write_seqs(seqs1, fast_1, gzipped=args.compress)
    if fast_2:
        write_seqs(seqs2, fast_2, gzipped=args.compress)

    # Timing results
    print(gray('Total elapsed time:'), time.strftime(
        "%H:%M:%S", time.gmtime(time.time() - start_time)))


# Run the tool only when this file is executed as a script, not on import
if __name__ == '__main__':
    main()
