#!python
#
#     Copyright (C) 2017–2025, Jose Manuel Martí Martínez
#
#     This program is free software: you can redistribute it and/or modify
#     it under the terms of the GNU Affero General Public License as
#     published by the Free Software Foundation, either version 3 of the
#     License, or (at your option) any later version.
#
#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU Affero General Public License for more details.
#
#     You should have received a copy of the GNU Affero General Public License
#     along with this program. If not, see <https://www.gnu.org/licenses/>.
#
"""
Extract entries from a FASTA file whose records are identified by accessions,
as in the NCBI nt database, selecting them by taxonomy through an
accession-to-taxid mapping file.
"""
# pylint: disable=no-name-in-module, not-an-iterable
import argparse
import gzip
import os
import sys
import time
from collections import Counter
from typing import Callable, List, Set, TextIO

from Bio import SeqIO

from recentrifuge import __version__, __author__, __date__
from recentrifuge.config import Filename, Id, GZEXT
from recentrifuge.config import LICENSE, NODES_FILE, NAMES_FILE, TAXDUMP_PATH
from recentrifuge.config import gray, red, green, cyan, magenta, blue, yellow
from recentrifuge.rank import Rank, Ranks, TaxLevels
from recentrifuge.taxonomy import Taxonomy
from recentrifuge.trees import TaxTree

# Internal constants
# NOTE(review): MAX_LENGTH_TAXID_LIST is not referenced anywhere else in this
#  script; confirm whether it is still needed before removing it.
MAX_LENGTH_TAXID_LIST = 32


def main():
    """Main entry point to script."""

    def configure_parser():
        """Build the command-line argument parser of the script.

        The options are arranged in the general section plus two groups:
        'selection' (taxids to include/exclude) and 'input' (the mapping
        file and the nt FASTA file). Both input files are mandatory, so a
        missing one is reported by argparse right away instead of failing
        later, once the taxonomy and the mapping file have been processed.

        Returns: The configured argparse.ArgumentParser object
        """
        parser = argparse.ArgumentParser(
            description='Extraction from fasta with accessions as NCBI nt DB',
            epilog=f'%(prog)s  - Release {__version__} - {__date__}' + LICENSE,
            formatter_class=argparse.RawDescriptionHelpFormatter
        )
        parser.add_argument(
            '-d', '--debug',
            action='store_true',
            help='increase output verbosity and perform additional checks'
        )
        parser.add_argument(
            '-l', '--limit',
            action='store',
            metavar='NUMBER',
            type=int,
            default=None,
            help=('limit of nt DB entries to extract; '
                  'default: no limit')
        )
        parser.add_argument(
            '-e', '--entrymax',
            action='store',
            metavar='NUMBER',
            type=int,
            default=None,
            help=('maximum number of nt DB entries to search for the taxa; '
                  'default: no maximum')
        )
        parser.add_argument(
            '-n', '--nodespath',
            action='store',
            metavar='PATH',
            default=TAXDUMP_PATH,
            help=('path for the nodes information files '
                  '(nodes.dmp and names.dmp from NCBI)')
        )
        parser_selection = parser.add_argument_group(
            'selection',
            'Selection of reads based on the taxonomy')
        parser_selection.add_argument(
            '-i', '--include',
            action='append',
            metavar='TAXID',
            type=Id,
            default=[],
            help=('NCBI taxid code to include a taxon and all underneath ' +
                  '(multiple -i is available to include several taxid); ' +
                  'by default all the taxa is considered for inclusion')
        )
        parser_selection.add_argument(
            '-x', '--exclude',
            action='append',
            metavar='TAXID',
            type=Id,
            default=[],
            help=('NCBI taxid code to exclude a taxon and all underneath ' +
                  '(multiple -x is available to exclude several taxid)')
        )
        parser_in = parser.add_argument_group(
            'input', 'Define input files')
        # The mapping file is an input file too, so it belongs to this group
        parser_in.add_argument(
            '-m', '--mapfile',
            action='store',
            metavar='FILE',
            required=True,
            help='Mapping (accession to taxid) file'
        )
        # Mandatory: the output filename is derived from this one and the
        #  script cannot do anything without the FASTA file to extract from
        parser_in.add_argument(
            '-f', '--ntfastafile',
            action='store',
            metavar='FILE',
            required=True,
            help='NCBI nt formatted FASTA file'
        )
        parser.add_argument(
            '-c', '--compress',
            action='store_true',
            help='Output FASTA file will be gzipped'
        )
        parser.add_argument(
            '-V', '--version',
            action='version',
            version=f'%(prog)s release {__version__} ({__date__})'
        )
        return parser

    def check_debug():
        """Print the active parameters when the debugging mode is on.

        Options whose value is None, False or an empty list are omitted
        from the listing. Nothing is printed without the debug flag.
        """
        if not args.debug:
            return
        print(blue('INFO:'), gray('Debugging mode activated'))
        print(blue('INFO:'), gray('Active parameters:'))
        for name, value in vars(args).items():
            # Skip options that carry no meaningful value
            if value is None or value is False or value == []:
                continue
            print(gray(f'\t{name} ='), f'{value}')

    def is_gzipped(fpath: Filename):
        """Tell whether a file exists and begins with the gzip signature.

        Args:
            fpath: Path of the file to probe

        Returns: True if the first two bytes of the file are 1f 8b;
            False otherwise, also when the file does not exist
        """
        gzip_magic = b'\x1f\x8b'
        try:
            handle = open(fpath, 'rb')
        except FileNotFoundError:
            # A missing file is reported as "not gzipped", not as an error
            return False
        with handle:
            header = handle.read(2)
        return header == gzip_magic

    def build_out_filename(fname: Filename) -> Filename:
        """Auxiliary function to properly build the output filename.

        The result is the input name stripped of its extension(s), followed
        by '_rxnt', then by '_incl' and/or '_excl' plus the corresponding
        taxids joined by '_' (when any were given), and finally by the
        extension, with the gzip one appended if the output is compressed.

        Args:
            fname: Complete filename of the FASTA input file (gzipped or not)

        Returns: Filename of the extracted FASTA output file
        """
        # Get filename and extension, accounting for 2 extensions,
        #  as typically found in filename.fasta.gz
        fasta_root, fasta_ext = os.path.splitext(fname)
        if fasta_ext.casefold() == GZEXT:
            fasta_root, fasta_ext = os.path.splitext(fasta_root)
        # The gzip extension is only kept/added for compressed output
        if args.compress and not fasta_ext.casefold().endswith(GZEXT):
            fasta_ext += GZEXT
        output_list: List[str] = [fasta_root, '_rxnt']
        # The taxids come from sets, whose iteration order may change between
        #  runs; sorting them makes the filename reproducible for the same
        #  command line
        if including:
            output_list.append('_incl')
            output_list.append('_'.join(sorted(including)))
        if excluding:
            output_list.append('_excl')
            output_list.append('_'.join(sorted(excluding)))
        output_list.append(fasta_ext)
        return Filename(''.join(output_list))

    # Reference for the total elapsed time reported at the end
    start_time: float = time.time()
    # Program header
    print(f'\n=-= {sys.argv[0]} =-= v{__version__} - {__date__}'
          f' =-= by {__author__} =-=\n')
    sys.stdout.flush()

    # Parse arguments (args, including and excluding are also read by the
    #  nested helper functions through the enclosing scope)
    argparser = configure_parser()
    args = argparser.parse_args()
    mapfile = args.mapfile
    nodesfile: Filename = Filename(os.path.join(args.nodespath, NODES_FILE))
    namesfile: Filename = Filename(os.path.join(args.nodespath, NAMES_FILE))
    excluding: Set[Id] = set(args.exclude)
    including: Set[Id] = set(args.include)
    ntfile: Filename = args.ntfastafile

    check_debug()

    # Load NCBI nodes, names and build children
    ncbi: Taxonomy = Taxonomy(nodesfile, namesfile, None,
                              False, excluding, including)

    # Build taxonomy tree
    print(gray('Building taxonomy tree...'), end='')
    sys.stdout.flush()
    tree = TaxTree()
    tree.grow(ontology=ncbi, look_ancestors=False)
    print(green(' OK!'))

    # Get the taxa
    print(gray('Filtering taxa...'), end='')
    sys.stdout.flush()
    ranks: Ranks = Ranks({})
    tree.get_taxa(ranks=ranks,
                  include=including,
                  exclude=excluding)
    print(green(' OK!'))
    taxids: Set[Id] = set(ranks)
    taxlevels: TaxLevels = Rank.ranks_to_taxlevels(ranks)
    num_taxlevels = Counter({rank: len(taxlevels[rank])
                             for rank in taxlevels})
    # Unary plus on a Counter drops the entries with non-positive counts
    num_taxlevels = +num_taxlevels

    # Statistics about including taxa
    print(f'  {len(taxids)}', gray('taxid selected in '), end='')
    print(f'{len(num_taxlevels)}', gray('different taxonomic levels:'))
    for rank in num_taxlevels:
        print(f'  Number of different {rank}: {num_taxlevels[rank]}')
    # Explicit check (same exception type as the former assert) so that the
    #  validation is not stripped when running with python -O
    if not taxids:
        raise AssertionError(red('ERROR! No taxids to search for!'))

    # Get the accessions from the mapping file
    accessions: List[str] = []
    # Index of the last mapping line read; stays at -1 for an empty file
    last_map_idx: int = -1
    # timing initialization
    start_time_load: float = time.perf_counter()
    print(gray('Parsing acc->tid mapping file'), mapfile,
          gray('...\nMmaps: '), end='')
    sys.stdout.flush()
    try:
        with open(mapfile, 'rt', newline=None) as file:
            for last_map_idx, mapp in enumerate(file):
                # Each line is expected to hold an accession and its taxid;
                #  any other layout raises ValueError, handled below
                acc, tid = mapp.strip().split(maxsplit=1)
                # Progress: millions read every 10 M lines, '_' every 1 M
                if not last_map_idx % 10_000_000:
                    print(f'{last_map_idx // 1_000_000:_d}', end='')
                    sys.stdout.flush()
                elif not last_map_idx % 1_000_000:
                    print('_', end='')
                    sys.stdout.flush()
                if tid in taxids:
                    accessions.append(acc)
    except (FileNotFoundError, ValueError) as err:
        raise Exception(
            red('ERROR!') + 'Cannot process "' + mapfile + '"') from err
    print(green(' OK!'))
    # enumerate() yields 0-based indexes, so the line count is one more
    num_maps: int = last_map_idx + 1

    # Basic accessions statistics (or exit if no matching read)
    print(gray('  Parse elapsed time: ') +
          f'{time.perf_counter() - start_time_load:.3g}' + gray(' sec'))
    if accessions:
        # num_maps is at least 1 here, as an accession was selected
        print(gray('  Selected accessions: '), f'{len(accessions):_d}\t',
              gray('('), f'{len(accessions) / num_maps:.4%}', gray('of DB )'))
        sys.stdout.flush()
    else:
        print(yellow('WARNING!'), 'No matching accession found! Exiting.')
        sys.exit(1)

    # Deals with NCBI nt formatted FASTA file
    # Accessions still to be found; each one is removed when extracted
    records_ids: Set[str] = set(accessions)
    extracted: int = 0
    nts: int = 0
    # Number of FASTA records actually examined (not a 0-based index)
    scanned: int = 0
    print(gray('Processing DB (nt FASTA)'), f'{ntfile}', gray('...\nMseqs: '),
          end='')
    sys.stdout.flush()
    outfilename: Filename = build_out_filename(ntfile)
    outhandler: Callable[..., TextIO]
    if args.compress:
        outhandler = gzip.open
    else:
        outhandler = open
    # The input is opened according to its content, not to its extension
    dbhandler: Callable[..., TextIO]
    if is_gzipped(ntfile):
        dbhandler = gzip.open
    else:
        dbhandler = open
    try:
        with dbhandler(ntfile, 'rt') as nt, \
                outhandler(outfilename, 'wt') as out:
            for idx, rec in enumerate(SeqIO.parse(nt, 'fasta')):
                if not records_ids:
                    print(green(' [all accessions found]'), end='')
                    break
                elif args.entrymax and idx >= args.entrymax:
                    print(yellow(' [stopping by entrymax limit]!'), end='')
                    break
                elif args.limit and extracted >= args.limit:
                    print(yellow(f' [stopping by {extracted} limit]!'),
                          end='')
                    break
                elif not idx % 1_000_000:
                    print(f'{idx // 1_000_000:_d}', end='')
                    sys.stdout.flush()
                elif not idx % 100_000:
                    print('.', end='')
                    sys.stdout.flush()
                # No break happened, so this record is examined
                scanned = idx + 1
                # Membership test instead of try/except: most records are
                #  expected to miss, and raising per miss is costly
                if rec.id in records_ids:
                    records_ids.remove(rec.id)
                    SeqIO.write(rec, out, 'fasta-2line')
                    nts += len(rec.seq)
                    extracted += 1
    except FileNotFoundError as err:
        raise Exception(
            '\n\033[91mERROR!\033[0m Cannot read nt DB FASTA file') from err
    # Avoid dividing by zero if no record was examined (e.g. empty FASTA)
    extracted_ratio: float = extracted / scanned if scanned else 0.0
    print(cyan(f' {scanned / 1e+6:.3g} Mseqs'), green('OK! '))
    print(gray('Wrote'), magenta(f'{extracted}'), gray('extracted entries ('),
          magenta(f'{nts / 1e+9:.3g}'), gray('Gnucs and'),
          magenta(f'{extracted_ratio:.4%}'), gray('of DB )\n\t'),
          magenta('gzipped' if args.compress else ''),
          gray('in'), outfilename)

    # Timing results
    print(gray('Total elapsed time:'), time.strftime(
        "%H:%M:%S", time.gmtime(time.time() - start_time)))


# Run main() only when the file is executed as a script, not when imported
if __name__ == '__main__':
    main()
