#!python
#
#     Copyright (C) 2017–2026, Jose Manuel Martí Martínez
#
#     This program is free software: you can redistribute it and/or modify
#     it under the terms of the GNU Affero General Public License as
#     published by the Free Software Foundation, either version 3 of the
#     License, or (at your option) any later version.
#
#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU Affero General Public License for more details.
#
#     You should have received a copy of the GNU Affero General Public License
#     along with this program. If not, see <https://www.gnu.org/licenses/>.
#
"""
Symmetrically split a FASTA file into several files for parallel processing.
"""
# pylint: disable=no-name-in-module, not-an-iterable
import argparse
import collections
import contextlib
import gzip
import sys
import time
from math import log10, ceil
from typing import Callable, Dict, TextIO, Counter

from Bio.SeqIO.FastaIO import SimpleFastaParser

from recentrifuge import __version__, __author__, __date__
from recentrifuge.config import Filename, LICENSE, GZEXT
from recentrifuge.config import gray, red, green, cyan, magenta, blue, yellow


def main():
    """Main entry point to script.

    Parse the command line, open ``--num`` output FASTA files named
    ``<outprefix><index>.fa`` (with ``GZEXT`` appended when ``--compress``
    is given) and distribute the sequences of the input FASTA among them:
    every sequence is written to the output file that has received the
    fewest nucleotides so far.

    Raises:
        SystemExit: through argparse, if ``--num``, ``--input`` or
            ``--outprefix`` is missing or ``--num`` is lower than 1.
        Exception: if the input FASTA file cannot be found.
    """

    def configure_parser():
        """Argument Parser Configuration"""
        parser = argparse.ArgumentParser(
            description='Symmetrically split a FASTA file into several files',
            epilog=f'%(prog)s  - Release {__version__} - {__date__}' + LICENSE,
            formatter_class=argparse.RawDescriptionHelpFormatter
        )
        parser.add_argument(
            '-d', '--debug',
            action='store_true',
            help='increase output verbosity and perform additional checks'
        )
        parser.add_argument(
            '-n', '--num',
            action='store',
            metavar='NUMBER',
            type=int,
            default=None,
            help='number of output FASTA files to generate'
        )
        parser.add_argument(
            '-i', '--input',
            action='store',
            metavar='FILE',
            type=Filename,
            default=None,
            help='single FASTA file (no paired-ends), which may be gzipped'
        )
        parser.add_argument(
            '-o', '--outprefix',
            action='store',
            metavar='PATH',
            type=Filename,
            default=None,
            help='path with prefix for the series of FASTA output files'
        )
        parser.add_argument(
            '-m', '--maxreads',
            action='store',
            metavar='NUMBER',
            type=int,
            default=None,
            help=('maximum number of FASTA reads to process; '
                  'default: no maximum')
        )
        parser.add_argument(
            '-c', '--compress',
            action='store_true',
            default=False,
            help='the resulting FASTA files will be gzipped'
        )
        parser.add_argument(
            '-V', '--version',
            action='version',
            version=f'%(prog)s release {__version__} ({__date__})'
        )
        return parser

    def check_debug():
        """Print the active (non-default) parameters in debugging mode"""
        if args.debug:
            print(blue('INFO:'), gray('Debugging mode activated'))
            print(blue('INFO:'), gray('Active parameters:'))
            for key, val in vars(args).items():
                if val is not None and val is not False and val != []:
                    print(gray(f'\t{key} ='), f'{val}')

    def is_gzipped(fpath: Filename) -> bool:
        """Check if a file exists and starts with the gzip magic number"""
        try:
            with open(fpath, 'rb') as ftest:
                return ftest.read(2) == b'\x1f\x8b'
        except FileNotFoundError:
            return False

    # timing initialization
    start_time: float = time.time()
    # Program header
    print(f'\n=-= {sys.argv[0]} =-= v{__version__} - {__date__}'
          f' =-= by {__author__} =-=\n')
    sys.stdout.flush()

    # Parse arguments
    argparser = configure_parser()
    args = argparser.parse_args()
    num_files: int = args.num
    in_fasta: Filename = args.input
    out_root: Filename = args.outprefix
    compr = args.compress

    check_debug()

    # Fail early with a usage message instead of an obscure TypeError or
    # ValueError raised later by log10(), open() or string concatenation
    if num_files is None or num_files < 1:
        argparser.error(
            'a positive number of output files is required (-n/--num)')
    if in_fasta is None:
        argparser.error('an input FASTA file is required (-i/--input)')
    if out_root is None:
        argparser.error('an output prefix is required (-o/--outprefix)')

    # Prepare output
    writer: Callable[..., TextIO]
    fext: str = '.fa'
    if compr:
        writer = gzip.open
        fext += GZEXT
    else:
        writer = open
    nts: Counter[str] = collections.Counter()  # Nucleotides per output
    out_files: Dict[str, TextIO] = {}
    omag: int = ceil(log10(num_files))  # Get order of magnitude ceil
    num_seqs: int = 0  # Sequences actually written

    # The ExitStack guarantees that every output file is flushed and closed
    # (essential for gzip streams), even if reading the input fails
    with contextlib.ExitStack() as stack:
        for n in range(num_files):
            key: str = f'{n:0>{omag}}'
            out_files[key] = stack.enter_context(
                writer(Filename(out_root + key + fext), 'wt'))
            nts[key] = 0  # Initialize counter so every output is a candidate
        # Keys in reverse creation order: min() returns the first minimum
        # it finds, so ties are resolved in favour of the last created
        # file, exactly as most_common()[-1] would do (stable sort) but
        # without sorting all the counters for every sequence
        keys_rev = list(nts)[::-1]

        # Main loop: split FASTA file
        print(gray('Splitting FASTA file'), f'{in_fasta}',
              gray('...\nMseqs: '), end='')
        sys.stdout.flush()
        reader: Callable[..., TextIO]
        if is_gzipped(in_fasta):
            reader = gzip.open
        else:
            reader = open
        try:
            with reader(in_fasta, 'rt') as fa:
                target: str
                for i, (title, seq) in enumerate(SimpleFastaParser(fa)):
                    if args.maxreads and i >= args.maxreads:
                        print(yellow(' [stopping by maxreads limit]!'),
                              end='')
                        break
                    elif not i % 1000000:
                        print(f'{i // 1000000:_d}', end='')
                        sys.stdout.flush()
                    elif not i % 100000:
                        print('.', end='')
                        sys.stdout.flush()
                    # Send the sequence to the least loaded output file
                    target = min(keys_rev, key=nts.__getitem__)
                    out_files[target].write(f'>{title}\n{seq}\n')
                    nts[target] += len(seq)
                    num_seqs += 1
        except FileNotFoundError as err:
            raise Exception(
                red('\nERROR!'), 'Cannot read FASTA file') from err

    # Report the number of sequences written (not the last loop index)
    print(cyan(f' {num_seqs / 1e+6:.3g} Mseqs'), green('OK! '))
    print(gray('Wrote'), magenta(f'{sum(nts.values()) / 1e+9:.3g}'),
          gray('Gnucs in'), num_files, magenta('gzipped' if compr else ''),
          gray('FASTA files.'))

    # Timing results
    print(gray('Total elapsed time:'), time.strftime(
        "%H:%M:%S", time.gmtime(time.time() - start_time)))


# Run the splitter only when executed as a script, not when imported
if __name__ == '__main__':
    main()
