#!/usr/bin/env python3
#
#     Copyright (C) 2017–2026, Jose Manuel Martí Martínez
#
#     This program is free software: you can redistribute it and/or modify
#     it under the terms of the GNU Affero General Public License as
#     published by the Free Software Foundation, either version 3 of the
#     License, or (at your option) any later version.
#
#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#     GNU Affero General Public License for more details.
#
#     You should have received a copy of the GNU Affero General Public License
#     along with this program. If not, see <https://www.gnu.org/licenses/>.
#
"""
Analyze metagenomic (taxonomic) classification data.
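
Typical invocation (illustrative; sample file names are placeholders):

    rcf -f smpl1.out -f smpl2.out -c 1 -o results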
"""
# pylint: disable=no-name-in-module, not-an-iterable
import argparse
import bz2
import collections as col
import multiprocessing as mp
import os
import pickle
import platform
import sys
import time
from typing import Counter, List, Dict, Set, Callable, Tuple

from recentrifuge import __version__, __author__, __date__
from recentrifuge.centrifuge import select_centrifuge_inputs
from recentrifuge.clark import select_clark_inputs
from recentrifuge.config import Filename, Sample, Id, Score, Scoring, Extra, UnionCounter, UnionScores
from recentrifuge.config import Summary
from recentrifuge.config import HTML_SUFFIX, PKL_SUFFIX, TAXDUMP_PATH
from recentrifuge.config import STATS_SHEET_NAME, LICENSE, Err, Classifier
from recentrifuge.config import NODES_FILE, NAMES_FILE, PLASMID_FILE
from recentrifuge.config import STR_CONTROL, STR_EXCLUSIVE, STR_SHARED
from recentrifuge.config import STR_CONTROL_SHARED
from recentrifuge.config import gray, red, green, yellow, blue, magenta
from recentrifuge.core import process_rank, summarize_analysis
from recentrifuge.generic import GenericFormat, select_generic_inputs
from recentrifuge.kraken import select_kraken_inputs
from recentrifuge.krona import COUNT, UNASSIGNED, SCORE
from recentrifuge.krona import KronaTree
from recentrifuge.lmat import select_lmat_inputs
from recentrifuge.rank import Rank, TaxLevels
from recentrifuge.stats import SampleStats
from recentrifuge.taxclass import process_output
from recentrifuge.taxonomy import Taxonomy
from recentrifuge.trees import TaxTree, MultiTree, SampleDataById

import openpyxl  # ensures an Excel engine for pd.ExcelWriter is available
import pandas as pd


def _debug_dummy_plot(taxonomy: Taxonomy,
                      htmlfile: Filename,
                      scoring: Scoring = Scoring.SHEL,
                      ):
    """

    Generate dummy Krona plot via Krona 2.0 XML spec and exit

    """
    print(gray(f'Generating dummy Krona plot {htmlfile}...'), end='')
    sys.stdout.flush()
    samples: List[Sample] = [Sample('SINGLE'), ]
    krona: KronaTree = KronaTree(samples,
                                 min_score=Score(35),
                                 max_score=Score(100),
                                 scoring=scoring,
                                 )
    polytree: MultiTree = MultiTree(samples=samples)
    polytree.grow(ontology=taxonomy)
    polytree.toxml(ontology=taxonomy, krona=krona)
    krona.tohtml(htmlfile, pretty=True)
    print(green('OK!'))


def main():
    """Main entry point to Recentrifuge."""

    def vprint(*arguments) -> None:
        """Print only if verbose/debug mode is enabled"""
        if args.debug:
            print(*arguments)

    def configure_parser():
        """Argument Parser Configuration"""
        parser = argparse.ArgumentParser(
            description=('Robust comparative analysis and contamination '
                         'removal for metagenomics'),
            epilog=f'%(prog)s - Release {__version__} - {__date__}' + LICENSE,
            formatter_class=argparse.RawDescriptionHelpFormatter
        )
        parser.add_argument(
            '-V', '--version',
            action='version',
            version=f'%(prog)s version {__version__} released in {__date__}'
        )
        parser_in = parser.add_argument_group(
            'input', 'Define Recentrifuge input files and formats')
        parser_in.add_argument(
            '-n', '--nodespath',
            action='store',
            metavar='PATH',
            default=TAXDUMP_PATH,
            help=('path for the nodes information files '
                  '(nodes.dmp and names.dmp from NCBI)')
        )
        parser_in.add_argument(
            '--format',
            action='store',
            metavar='GENERIC_FORMAT',
            type=str,
            default=None,
            help=('format of the output files from a generic classifier '
                  'provided with the option -g; it is a string like '
                  '"TYP:csv,TID:1,LEN:3,SCO:6,UNC:0" where valid file TYPes '
                  'are csv/tsv/ssv, the next fields give the column number '
                  '(starting at 1) of the assigned TaxID, the LENgth of the '
                  'read, and the SCOre of the assignment, and UNC is the '
                  'taxid code used for UNClassified reads')
        )
        parser_filein = parser_in.add_mutually_exclusive_group(required=True)
        parser_filein.add_argument(
            '-f', '--file',
            action='append',
            metavar='FILE',
            type=Filename,
            help=('Centrifuge output files; if a single directory is entered, '
                  'every .out file inside will be taken as a different sample;'
                  ' multiple -f is available to include several Centrifuge '
                  'samples')
        )
        parser_filein.add_argument(
            '-g', '--generic',
            action='append',
            metavar='FILE',
            type=Filename,
            help=('output file from a generic classifier; it requires the flag'
                  ' --format (see such option for details); if a single '
                  'directory is entered, every file inside will be taken as a '
                  'different sample; multiple -g is available to include '
                  'several generic samples by filename')
        )
        parser_filein.add_argument(
            '-l', '--lmat',
            action='append',
            metavar='FILE',
            type=Filename,
            default=None,
            help=('LMAT output dir or file prefix; if just "." is entered, '
                  'every subdirectory under the current directory will be '
                  'taken as a sample and scanned looking for LMAT output files'
                  '; multiple -l is available to include several samples')
        )
        parser_filein.add_argument(
            '-r', '--clark',
            action='append',
            metavar='FILE',
            type=Filename,
            help=('CLARK full-mode output files; if a single directory is '
                  'entered, every .csv file inside will be taken as a '
                  'different sample; multiple -r is available to include '
                  'several CLARK, CLARK-l, and CLARK-S full-mode samples')
        )
        parser_filein.add_argument(
            '-k', '--kraken',
            action='append',
            metavar='FILE',
            type=Filename,
            help=('Kraken output files; if a single directory is entered, '
                  'every .krk file inside will be taken as a different sample;'
                  ' multiple -k is available to include several Kraken '
                  '(version 1 or 2) samples by filename')
        )
        parser_out = parser.add_argument_group(
            'output', 'Related to the Recentrifuge output files')
        parser_out.add_argument(
            '-o', '--outprefix',
            action='store',
            metavar='FILE',
            type=Filename,
            help='output prefix; if not given, it will be inferred '
                 'from the input files; an HTML filename is still accepted '
                 'for backwards compatibility with the legacy --outhtml option'
        )
        parser_out.add_argument(
            '-e', '--extra',
            action='store',
            metavar='OUTPUT_TYPE',
            choices=[str(xtra) for xtra in Extra],
            default=str(Extra(0)),
            help=(f'type of extra output to be generated; it can be one of '
                  f'{[str(xtra) for xtra in Extra]}')
        )
        parser_out.add_argument(
            '-p', '--pickle',
            action='count',
            default=0,
            help=('pickle (serialize) statistics and data results in pandas '
                  'DataFrames (format affected by the selection of --extra); '
                  'repeat the flag (-pp) to also serialize the input samples '
                  'as a dict of sample names to TaxTree objects')
        )
        parser_out.add_argument(
            '--nohtml',
            action='store_true',
            help='suppress saving the HTML output file'
        )
        parser_coarse = parser.add_argument_group(
            'tuning',
            'Coarse tuning of algorithm parameters')
        parser_cross = parser_coarse.add_mutually_exclusive_group(
            required=False)
        parser_cross.add_argument(
            '-a', '--avoidcross',
            action='store_true',
            help='avoid cross analysis'
        )
        parser_cross.add_argument(
            '-c', '--controls',
            action='store',
            metavar='CONTROLS_NUMBER',
            type=int,
            default=0,
            help=('treat this number of initial samples as negative '
                  'controls; default is no controls')
        )
        parser_coarse.add_argument(
            '-s', '--scoring',
            action='store',
            metavar='SCORING',
            choices=[str(each_score) for each_score in Scoring],
            default=str(Scoring(0)),
            help=(f'type of scoring to be applied; it can be one of '
                  f'{[str(scr) for scr in Scoring]}')
        )
        parser_coarse.add_argument(
            '-y', '--minscore',
            action='store',
            metavar='NUMBER',
            type=lambda txt: Score(float(txt)),
            default=None,
            help=('minimum score/confidence of the classification of a read '
                  'to pass the quality filter; all pass by default')
        )
        parser_coarse.add_argument(
            '-m', '--mintaxa',
            action='store',
            metavar='INT',
            type=int,
            default=None,
            help=('minimum taxa to avoid collapsing one level into the parent'
                  ' (if not specified, a value is automatically assigned)')
        )
        parser_coarse.add_argument(
            '-x', '--exclude',
            action='append',
            metavar='TAXID',
            type=Id,
            default=[],
            help=('NCBI taxid code to exclude a taxon and all underneath '
                  '(multiple -x is available to exclude several taxid)')
        )
        parser_coarse.add_argument(
            '-i', '--include',
            action='append',
            metavar='TAXID',
            type=Id,
            default=[],
            help=('NCBI taxid code to include a taxon and all underneath '
                  '(multiple -i is available to include several taxid); '
                  'by default, all the taxa are considered for inclusion')
        )
        parser_fine = parser.add_argument_group(
            'fine tuning',
            'Fine tuning of algorithm parameters')
        parser_fine.add_argument(
            '-z', '--ctrlminscore',
            action='store',
            metavar='NUMBER',
            type=lambda txt: Score(float(txt)),
            default=None,
            help=('minimum score/confidence of the classification of a read '
                  'in control samples to pass the quality filter; it defaults '
                  'to "minscore"')
        )
        parser_fine.add_argument(
            '-w', '--ctrlmintaxa',
            action='store',
            metavar='INT',
            type=int,
            default=None,
            help=('minimum taxa in control samples to avoid collapsing one '
                  'level into the parent (if not specified, a value is '
                  'automatically assigned)')
        )
        parser_fine.add_argument(
            '-u', '--summary',
            action='store',
            metavar='SUMMARY_BEHAVIOR',
            choices=[str(summ) for summ in Summary],
            default=str(Summary(0)),
            help=(f'choice for summary behaviour; it can be one of '
                  f'{[str(summ) for summ in Summary]}')
        )
        parser_fine.add_argument(
            '-t', '--takeoutroot',
            action='store_true',
            help='remove counts directly assigned to the "root" level'
        )
        parser_fine.add_argument(
            '--nokollapse',
            action='store_true',
            help='show the "cellular organisms" taxon'
        )
        parser_mode = parser.add_argument_group('advanced',
                                                'Advanced modes of running')
        parser_mode.add_argument(
            '--dummy',  # hidden flag: just generate a dummy plot for JS debug
            action='store_true',
            help=argparse.SUPPRESS
        )
        parser_mode.add_argument(
            '-d', '--debug',
            action='store_true',
            help='increase output verbosity and perform additional checks'
        )
        parser_mode.add_argument(
            '--no-strain',
            action='store_true',
            help=('disable strain level as the resolution limit for the robust'
                  ' contamination removal algorithm; by default, strain level '
                  'is enabled')
        )
        parser_mode.add_argument(
            '--sequential',
            action='store_true',
            help='deactivate parallel processing'
        )
        parser_mode.add_argument(
            '-T', '--threads',
            action='store',
            metavar='NUMBER',
            type=int,
            default=0,
            help=('number of threads to use for parallel processing; '
                  '0 (default) means legacy mode using min(cpu_count, samples)')
        )
        return parser

    def parse_generic() -> GenericFormat:
        """Check and parse generic format"""
        if args.format is None:
            print(red('ERROR!'), gray(
                'Format specifier (--format) is mandatory with output from a '
                'generic classifier.\n\tPlease run with --help for details.'))
            exit(2)
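        # The spec is a comma-separated key:value string, e.g.
        # 'TYP:tsv,TID:2,LEN:4,SCO:5,UNC:0' (illustrative values); see the
        # --format help above for the meaning of each key.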
        return GenericFormat(args.format)

    def get_num_processes(num_items: int) -> int:
        """Calculate number of processes for multiprocessing Pool.
        
        Args:
            num_items: The number of items to process (samples, ranks, etc.)
            
        Returns:
            Number of processes to use in the Pool.
            
        If --threads is 0 (default/legacy mode), uses min(cpu_count, num_items).
        If --threads > 0, uses min(threads, num_items) to avoid idle processes.
        """
        if args.threads == 0:
            # Legacy behavior: use min of CPU count and number of items
            cpu_count = os.cpu_count() or 1
            return min(cpu_count, num_items)
        else:
            # User specified number of threads
            return min(args.threads, num_items)

    def check_advanced_modes():
        """Check advanced modes of running"""
        nonlocal target_ranks  # type: ignore  # mypy issue #7057

        # Check for debugging mode
        if args.debug:
            print(blue('INFO:'), gray('Debugging mode activated'))
            print(blue('INFO:'), gray('Active parameters:'))
            for key, val in vars(args).items():
                if val is not None and val is not False and val != []:
                    print(gray(f'\t{key} ='), f'{val}')
            if genfmt is not None:
                print(blue('INFO:'), gray(genfmt))
        # Check for sequential mode
        if args.sequential:
            print(blue('INFO:'), gray('Parallel mode deactivated'))
        # Check for threads option
        if args.threads > 0:
            print(blue('INFO:'), gray(f'Using {args.threads} thread(s) for parallel processing'))
        # Check for no-strain mode (strain is now default)
        if args.no_strain:
            print(blue('INFO:'), gray('Strain level disabled'))
            target_ranks = Rank.selected_ranks
        else:
            target_ranks = Rank.selected_ranks_ext

    def select_inputs():
        """Choose right classifier, scoring, input and output files"""
        nonlocal process, scoring, input_files  # type: ignore  # mypy #7057
        nonlocal plasmidfile, classifier  # type: ignore  # mypy issue #7057
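        # The required mutually exclusive input group guarantees that
        # exactly one of generics/krakens/clarks/lmats/outputs is non-empty.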

        if generics:
            classifier = Classifier.GENERIC
            process = process_output
            if scoring not in [Scoring.GENERIC, Scoring.LENGTH,
                               Scoring.LOGLENGTH, Scoring.NORMA]:
                print(yellow('WARNING!'), 'Scoring scheme not supported for '
                                          'generic classifier; using GENERIC.')
                scoring = Scoring.GENERIC
            input_files = generics
            if len(generics) == 1 and os.path.isdir(generics[0]):
                select_generic_inputs(generics)
        elif krakens:
            classifier = Classifier.KRAKEN
            process = process_output
            if scoring not in [Scoring.SHEL, Scoring.LENGTH, Scoring.LOGLENGTH,
                               Scoring.NORMA, Scoring.KRAKEN]:
                print(yellow('WARNING!'), 'Scoring scheme not supported for '
                                          'Kraken classifier; using KRAKEN.')
                scoring = Scoring.KRAKEN
            input_files = krakens
            if len(krakens) == 1 and os.path.isdir(krakens[0]):
                select_kraken_inputs(krakens)
        elif clarks:
            classifier = Classifier.CLARK
            process = process_output
            if scoring in [Scoring.LMAT, Scoring.KRAKEN]:
                print(yellow('WARNING!'), 'Scoring scheme not supported for '
                                          'CLARK classifier; using SHEL.')
                scoring = Scoring.SHEL
            input_files = clarks
            if len(clarks) == 1 and os.path.isdir(clarks[0]):
                select_clark_inputs(clarks)
        elif lmats:
            classifier = Classifier.LMAT
            process = process_output
            if scoring is not Scoring.LMAT:
                print(yellow('WARNING!'), 'Scoring scheme not supported for '
                                          'LMAT classifier; using LMAT.')
                scoring = Scoring.LMAT
            input_files = lmats
            plasmidfile = Filename(os.path.join(args.nodespath, PLASMID_FILE))
            select_lmat_inputs(lmats)
        elif outputs:
            classifier = Classifier.CENTRIFUGE
            process = process_output
            if scoring not in [Scoring.SHEL, Scoring.LENGTH, Scoring.LOGLENGTH,
                               Scoring.NORMA]:
                print(yellow('WARNING!'), 'Scoring scheme not supported for '
                                          'Centrifuge classifier; using SHEL.')
                scoring = Scoring.SHEL
            input_files = outputs
            if len(outputs) == 1 and os.path.isdir(outputs[0]):
                select_centrifuge_inputs(outputs)
        else:
            raise Exception(red('\nERROR!') + ' Unknown classifier!')

    def check_controls():
        """Check and info about the control samples"""
        if args.controls:
            if args.controls > len(input_files):
                print(red(' ERROR!'), gray('More controls than samples'))
                exit(1)
            print(gray('Control(s) sample(s) for subtractions:'))
            for i in range(args.controls):
                print(blue(f'\t{input_files[i]}'))

    def infer_rcf_outputs():
        """Infer output prefix and HTML filename from inputs"""
        nonlocal out_prefix, htmlfile  # type: ignore  # mypy #7057
        if lmats:  # Select case for dir name or filename prefix
            if os.path.isdir(lmats[0]):  # Dir name
                dirname = os.path.dirname(os.path.normpath(lmats[0]))
                if not dirname or dirname == '.':
                    basename = 'output'
                else:
                    basename = os.path.basename(dirname)
            else:  # Explicit path and file name prefix is provided
                dirname, basename = os.path.split(lmats[0])
            out_prefix = Filename(os.path.join(dirname, basename))
            htmlfile = Filename(out_prefix + HTML_SUFFIX)
        else:
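            # Derive the prefix from the first input file name, cutting at
            # any '_mhl' tag (presumably a min-hit-length suffix) if present.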
            out_prefix = Filename(input_files[0].split('_mhl')[0])
            htmlfile = Filename(out_prefix + HTML_SUFFIX)

    def read_samples():
        """Read samples"""
        print(gray('\nPlease, wait, processing files in parallel...\n'))
        void_sample_cnt: int = 0
        # Enable parallelization with the 'fork' start method (POSIX systems)
        if platform.system() and not args.sequential:  # Only for known systems
            mpctx = mp.get_context('fork')
            with mpctx.Pool(processes=get_num_processes(
                                          len(input_files))) as pool:
                async_results = [pool.apply_async(
                    process,
                    args=[input_files[num],  # file name
                          num < args.controls],  # is it a control sample?
                    kwds=kwargs
                ) for num in range(len(input_files))]
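                # apply_async() results keep submission order, so the zip
                # below pairs each outcome with its originating input file.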
                for file, (sample, tree, out, stat, err) in zip(
                        input_files, [r.get() for r in async_results]):
                    if err is Err.NO_ERROR:
                        samples.append(sample)
                        trees[sample] = tree
                        taxids[sample] = out.get_taxlevels()
                        if out.counts is not None:
                            counts[sample] = out.counts
                        if out.accs is not None:
                            accs[sample] = out.accs
                        if out.scores is not None:
                            scores[sample] = out.scores
                        stats[sample] = stat
                        mintaxas[sample] = stat.mintaxa
                    elif err is Err.VOID_SAMPLE:
                        void_sample_cnt += 1
                    elif err is Err.VOID_CTRL:
                        print('There were void controls.', red('Aborting!'))
                        exit(1)
        else:  # sequential processing of each sample
            for num, file in enumerate(input_files):
                (sample, tree, out, stat, err) = process(
                    file, num < args.controls, **kwargs)
                if err is Err.NO_ERROR:
                    samples.append(sample)
                    trees[sample] = tree
                    taxids[sample] = out.get_taxlevels()
                    if out.counts is not None:
                        counts[sample] = out.counts
                    if out.accs is not None:
                        accs[sample] = out.accs
                    if out.scores is not None:
                        scores[sample] = out.scores
                    stats[sample] = stat
                    mintaxas[sample] = stat.mintaxa
                elif err is Err.VOID_SAMPLE:
                    void_sample_cnt += 1
                elif err is Err.VOID_CTRL:
                    print('There were void controls.', red('Aborting!'))
                    exit(1)
        raw_samples.extend(samples)  # Store raw sample names
        if not raw_samples:
            print('All the samples are void.', red('Aborting!'))
            exit(2)
        elif void_sample_cnt:
            print(yellow('Warning! '), void_sample_cnt, ' of the ',
                  len(input_files), ' input samples are ', yellow('VOID!'),
                  sep='')

    def analyze_samples():
        """Cross analysis of samples in parallel by taxlevel"""
        print(gray('Please, wait. Performing cross analysis in parallel...\n'))
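        # Each process_rank() call returns the derived cross-analysis
        # samples for that rank plus their counts, accumulated counts, and
        # scores, which are merged into the shared dictionaries below.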
        if platform.system() and not args.sequential:  # Only for known systems
            mpctx = mp.get_context('fork')  # macOS defaults to 'spawn' since 3.8
            with mpctx.Pool(processes=get_num_processes(len(
                    target_ranks))) as pool:
                async_results = [pool.apply_async(
                    process_rank,
                    args=[level],
                    kwds=kwargs
                ) for level in target_ranks]
                for level, (smpls, abunds, accumulators, score) in zip(
                        target_ranks,
                        [r.get() for r in async_results]):
                    samples.extend(smpls)
                    counts.update(abunds)
                    accs.update(accumulators)
                    scores.update(score)
        else:  # sequential processing of each selected rank
            for level in target_ranks:
                (smpls, abunds,
                 accumulators, score) = process_rank(level, **kwargs)
                samples.extend(smpls)
                counts.update(abunds)
                accs.update(accumulators)
                scores.update(score)

    def summarize_samples():
        """Summary of samples in parallel by type of cross-analysis"""
        # timing initialization
        summ_start_time: float = time.perf_counter()
        print(gray('Please, wait. Generating summaries in parallel...'))
        # Update kwargs with more parameters for the following func calls
        kwargs.update({'samples': samples})
        # Get list of set of samples to summarize (note pylint bug #776)
        # pylint: disable=unsubscriptable-object
        target_analysis: col.OrderedDict[str, None] = col.OrderedDict({
            f'{raw}_{study}': None for study in [STR_EXCLUSIVE, STR_CONTROL]
            for raw in raw_samples
            for smpl in samples if smpl.startswith(f'{raw}_{study}')
        })
        # pylint: enable=unsubscriptable-object
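        # target_analysis keys are f'{raw}_{STR_EXCLUSIVE}' and
        # f'{raw}_{STR_CONTROL}' for each raw sample that yielded such
        # cross-analysis samples; the OrderedDict acts as an
        # insertion-ordered set (values are just None).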
        # Add shared and control_shared analysis if they exist (are not void)
        for study in [STR_SHARED, STR_CONTROL_SHARED]:
            for smpl in samples:
                if smpl.startswith(study):
                    target_analysis[study] = None
                    break

        if platform.system() and not args.sequential:  # Only for known systems
            mpctx = mp.get_context('fork')
            with mpctx.Pool(processes=get_num_processes(
                                          len(target_analysis))) as pool:
                async_results = [pool.apply_async(
                    summarize_analysis,
                    args=[analysis],
                    kwds=kwargs
                ) for analysis in target_analysis]
                for analysis, (summary, abund, acc, score) in zip(
                        target_analysis, [r.get() for r in async_results]):
                    if summary:  # Avoid adding empty samples
                        summaries.append(summary)
                        counts[summary] = abund
                        accs[summary] = acc
                        scores[summary] = score
        else:  # sequential processing of each selected analysis
            for analysis in target_analysis:
                (summary, abund,
                 acc, score) = summarize_analysis(analysis, **kwargs)
                if summary is not None:  # Avoid adding empty samples
                    summaries.append(summary)
                    counts[summary] = abund
                    accs[summary] = acc
                    scores[summary] = score
        # Timing results
        print(gray('Summary elapsed time:'),
              f'{time.perf_counter() - summ_start_time:.3g}', gray('sec'))

    def generate_krona(html_out: Filename):
        """Generate Krona plot with all the results via Krona 2.0 XML spec"""

        print(gray('\nBuilding the taxonomy multiple tree... '), end='')
        sys.stdout.flush()
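        # The global min/max over every non-empty sample's scores set the
        # endpoints of the Krona score color scale.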
        krona: KronaTree = KronaTree(samples,
                                     num_raw_samples=len(raw_samples),
                                     stats=stats,
                                     min_score=Score(
                                         min([min(scores[sample].values())
                                              for sample in samples
                                              if len(scores[sample])])),
                                     max_score=Score(
                                         max([max(scores[sample].values())
                                              for sample in samples
                                              if len(scores[sample])])),
                                     scoring=scoring,
                                     )
        polytree.grow(ontology=ncbi,
                      abundances=counts,
                      accs=accs,
                      scores=scores)
        print(green('OK!'))
        if not args.nohtml:
            print(gray('Generating interactive plot (') + magenta(html_out) +
                  gray(')... '), end='')
            sys.stdout.flush()
            polytree.toxml(ontology=ncbi, krona=krona)
            krona.tohtml(html_out, pretty=False)
            print(green('OK!'))

    def save_extra_output(html_out: Filename):
        """Save extra output with results via pandas DataFrame"""

        # Initial check and info
        if extra is Extra.FULL or extra is Extra.DYNOMICS:
            vprint(blue('INFO:'),
                   gray('Saving extra output as an Excel file.'))
        elif extra is Extra.CSV:
            vprint(blue('INFO:'),
                   gray('Saving extra output as CSV files.'))
        elif extra is Extra.TSV:
            vprint(blue('INFO:'),
                   gray('Saving extra output as TSV files.'))
        elif extra is Extra.MULTICSV:
            vprint(blue('INFO:'),
                   gray('Saving extra output as multiple CSV files.'))
        else:
            raise Exception(f'ERROR! Unknown Extra option "{extra}"')

        # Setup of extra file name
        xlsxwriter: pd.ExcelWriter | None = None
        sv_base: Filename = Filename(html_out.split('.html')[0])
        sv_ext: Filename = Filename('')  # Will be set below for CSV/TSV modes
        sv_sep: str = ','  # Default CSV separator
        data_frame: pd.DataFrame | None = None  # Initialize for pickle section
        if extra is Extra.FULL or extra is Extra.DYNOMICS:
            xlsx_name: Filename = Filename(html_out.split('.html')[0] +
                                           '.xlsx')
            print(gray(f'Generating Excel {str(extra).lower()} summary (') +
                  magenta(xlsx_name) + gray(')... '), end='')
            sys.stdout.flush()
            xlsxwriter = pd.ExcelWriter(xlsx_name)
        elif extra in [Extra.CSV, Extra.TSV, Extra.MULTICSV]:
            if extra in [Extra.CSV, Extra.MULTICSV]:
                sv_ext = Filename('csv')
                sv_sep = ','
            elif extra is Extra.TSV:
                sv_ext = Filename('tsv')
                sv_sep = '\t'
            print(gray(f'Generating {str(extra).lower()} extra output (') +
                  magenta('[' + sv_base + '.]*.' + sv_ext)
                  + gray(')... '), end='')
            sys.stdout.flush()
        else:
            raise Exception(f'ERROR! Unknown Extra option "{extra}"')
        odict_rows: Dict[Id, List] = col.OrderedDict()

        # Save basic statistics of raw samples
        stat_frame = pd.DataFrame.from_records(
            {raw: stats[raw].to_odict() for raw in raw_samples},
            index=list(stats[raw_samples[0]].to_odict().keys()))
        if extra is Extra.FULL or extra is Extra.DYNOMICS:
            stat_frame.to_excel(xlsxwriter, sheet_name=STATS_SHEET_NAME)
        elif extra in [Extra.CSV, Extra.TSV, Extra.MULTICSV]:
            stat_frame.to_csv(Filename(sv_base + '.stat.' + sv_ext), sep=sv_sep)

        # Save taxid related statistics per sample
        if extra in [Extra.FULL, Extra.CSV, Extra.TSV, Extra.MULTICSV]:
            polytree.to_odict(ontology=ncbi, odict=odict_rows)
            # Generate the pandas DataFrame from items and export to Extra
            iterable_1 = [samples, [COUNT, UNASSIGNED, SCORE]]
            cols1 = pd.MultiIndex.from_product(iterable_1,
                                               names=['Samples', 'Stats'])
            iterable_2 = [['Details'], ['Rank', 'Name']]
            cols2 = pd.MultiIndex.from_product(iterable_2)
            cols = cols1.append(cols2)
            data_frame = pd.DataFrame.from_dict(
                odict_rows, orient='index', columns=cols)
            data_frame.index.names = ['Id']
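            # Layout: one row per taxid (Id index); a column block per
            # sample with COUNT, UNASSIGNED and SCORE stats, plus a final
            # 'Details' block holding Rank and Name.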
            if extra in [Extra.CSV, Extra.TSV]:
                data_frame.to_csv(Filename(sv_base + '.data.' + sv_ext), sep=sv_sep)
            elif extra is Extra.MULTICSV:
                for sample in samples:  # Only save taxa with counts per sample
                    sub_df = data_frame[
                        data_frame[sample, 'count'] > 0][[sample, 'Details']]
                    csv_name: Filename
                    # Shared samples need help (sv_base) to get the path right
                    if sample.startswith(STR_SHARED):
                        csv_name = Filename(sv_base + '_' + sample
                                            + '.data.' + sv_ext)
                    else:
                        csv_name = Filename(sample + '.data.' + sv_ext)
                    sub_df.to_csv(csv_name, sep=sv_sep)
            else:
                data_frame.to_excel(xlsxwriter, sheet_name=str(extra))
        elif extra is Extra.DYNOMICS:
            dyn_target_ranks: List[Rank] = [Rank.NO_RANK]
            if args.controls:  # if controls, add specific ranks as targets
                if args.no_strain:
                    dyn_target_ranks.extend(Rank.selected_ranks)
                else:
                    dyn_target_ranks.extend(Rank.selected_ranks_ext)
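            # One sheet per target: NO_RANK gathers the raw samples, while,
            # with controls, each rank gets a sheet with its control-
            # subtracted samples (cmplxcruncher-style layout).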
            for rank in dyn_target_ranks:  # NO_RANK: rank-independent pass
                indexes: List[int] = []  # Initialize as empty list
                sheet_name: str = 'RCF sheet'
                columns: List[str] = ['__Rank', '__Name']
                if args.controls:
                    indexes = [i for i in range(len(raw_samples), len(samples))
                               # Check if sample ends in _(STR_CONTROL)_(rank)
                               if (STR_CONTROL in samples[i].split('_')[-2:]
                                   and rank.name.lower() in
                                   samples[i].split('_')[-1:])]
                    sheet_name = f'{STR_CONTROL}_{rank.name.lower()}'
                    columns.extend([samples[i].replace(
                        '_' + STR_CONTROL + '_' + rank.name.lower(), '')
                        for i in indexes])
                if rank is Rank.NO_RANK:  # No rank dependency
                    indexes = list(range(len(raw_samples)))
                    sheet_name = f'raw_samples_{rank.name.lower()}'
                    columns.extend(raw_samples)
                polytree.to_odict(ontology=ncbi, odict=odict_rows,
                                  cmplxcruncher=True,
                                  sample_indexes=indexes)
                data_frame = pd.DataFrame.from_dict(odict_rows,
                                                    orient='index',
                                                    columns=columns)
                data_frame.index.names = ['Id']
                data_frame.to_excel(xlsxwriter, sheet_name=sheet_name)
        else:
            raise Exception(f'ERROR! Unknown Extra option "{extra}"')
        if xlsxwriter is not None:
            xlsxwriter.close()
        print(green('OK!'))

        # Serialize dataframes in compressed pickle files
        if args.pickle > 0:
            # Serialize dataframe with basic statistics of raw samples
            vprint(blue('INFO:'),
                   gray('Serializing dataframes with stats and data.'))
            pkl_file: Filename = Filename(sv_base + '.stat' + PKL_SUFFIX)
            print(gray('Generating output file (') +
                  magenta(pkl_file) + gray(')... '), end='')
            sys.stdout.flush()
            stat_frame.to_pickle(pkl_file, compression='infer', protocol=-1)
            print(green('OK!'))
            # Serialize dataframe with taxid related statistics per sample
            pkl_file = Filename(sv_base + '.data' + PKL_SUFFIX)
            print(gray('Generating output file (') +
                  magenta(pkl_file) + gray(')... '), end='')
            sys.stdout.flush()
            if data_frame is not None:
                data_frame.to_pickle(pkl_file, compression='infer', protocol=-1)
            print(green('OK!'))
            if args.pickle > 1:  # Serializing 'samples' as well
                vprint(blue('INFO:'),
                       gray("Serializing input samples' dict of TaxTree(s)."))
                pkl_file = Filename(sv_base + '.samples' + PKL_SUFFIX)
                print(gray('Generating output file (') +
                      magenta(pkl_file) + gray(')... '), end='')
                sys.stdout.flush()
                with bz2.BZ2File(pkl_file, 'wb') as fbz2:
                    pickle.dump(trees, fbz2, pickle.HIGHEST_PROTOCOL)
                print(green('OK!'))

    # timing initialization
    start_time: float = time.time()
    # Program header
    print(f'\n=-= {sys.argv[0]} =-= v{__version__} - {__date__}'
          f' =-= by {__author__} =-=\n')
    sys.stdout.flush()

    # Parse arguments
    argparser = configure_parser()
    args: argparse.Namespace = argparser.parse_args()
    clarks: List[Filename] = args.clark
    generics: List[Filename] = args.generic
    krakens: List[Filename] = args.kraken
    lmats: List[Filename] = args.lmat
    outputs: List[Filename] = args.file
    input_files: List[Filename]
    nodesfile: Filename = Filename(os.path.join(args.nodespath, NODES_FILE))
    namesfile: Filename = Filename(os.path.join(args.nodespath, NAMES_FILE))
    out_prefix: Filename = Filename(args.outprefix)
    collapse: bool = not args.nokollapse
    excluding: Set[Id] = set(args.exclude)
    including: Set[Id] = set(args.include)
    # parse those arguments related to enumerations
    scoring: Scoring = Scoring[args.scoring]
    extra: Extra = Extra[args.extra]
    summary: Summary = Summary[args.summary]

    # Target ranks are currently hardwired in the Rank class
    target_ranks: List[Rank] = Rank.selected_ranks

    # Check and parse generic format, if selected
    genfmt: GenericFormat | None = None
    if args.generic:
        genfmt = parse_generic()

    # Check arguments related to advanced modes of running
    check_advanced_modes()

    # Choose right classifier, scoring, control and regular input files
    plasmidfile: Filename | None = None
    classifier: Classifier | None = None
    process: Callable[..., Tuple[Sample, TaxTree, SampleDataById,
                                 SampleStats, Err]]
    select_inputs()
    check_controls()

    # Select the right output files (output prefix and HTML filename)
    htmlfile: Filename | None = None
    if out_prefix is None:  # Infer from inputs
        infer_rcf_outputs()
    elif out_prefix.endswith('.html') or out_prefix.endswith('.htm'):
        htmlfile = out_prefix  # Support legacy --outhtml option
    else:
        htmlfile = Filename(out_prefix + HTML_SUFFIX)

    # Ensure htmlfile is set (all code paths above should set it)
    if htmlfile is None:
        print(red('ERROR!'), gray('Could not determine output HTML filename'))
        exit(1)

    # Load NCBI nodes, names and build children
    ncbi: Taxonomy = Taxonomy(nodesfile, namesfile, plasmidfile,
                              collapse, excluding, including, args.debug)

    # If dummy flag enabled, just create dummy krona and exit
    if args.dummy:
        _debug_dummy_plot(ncbi, htmlfile, scoring)
        exit(0)

    # Declare variables of known data before the core analysis
    samples: List[Sample] = []
    raw_samples: List[Sample] = []

    # Declare variables that will hold results for the samples analyzed
    trees: Dict[Sample, TaxTree] = {}
    counts: Dict[Sample, UnionCounter] = {}
    accs: Dict[Sample, Counter[Id]] = {}
    taxids: Dict[Sample, TaxLevels] = {}
    scores: Dict[Sample, UnionScores] = {}
    stats: Dict[Sample, SampleStats] = {}
    mintaxas: Dict[Sample, int] = {}

    # Define dictionary of parameters for methods to be called (to be extended)
    kwargs = {'controls': args.controls,
              'ontology': ncbi,
              'classifier': classifier,
              'genfmt': genfmt,
              'scoring': scoring,
              'minscore': args.minscore,
              'ctrlminscore': (
                  args.ctrlminscore
                  if args.ctrlminscore is not None else args.minscore),
              'mintaxa': args.mintaxa,
              'ctrlmintaxa': args.ctrlmintaxa,
              'debug': args.debug, 'root': args.takeoutroot,
              }
    # Read the samples in parallel
    read_samples()
    # Update kwargs with more parameters for the following func calls
    kwargs.update({'taxids': taxids, 'counts': counts, 'scores': scores,
                   'accs': accs, 'raw_samples': raw_samples,
                   'mintaxas': mintaxas})
    kwargs.pop('mintaxa')  # mintaxa info is included in mintaxas dict
    kwargs.pop('ctrlmintaxa')  # ctrlmintaxa info is included in mintaxas dict

    # Skip the cross analysis for a single sample or when disabled by flag
    if len(raw_samples) > 1 and not args.avoidcross:
        analyze_samples()
        if summary is not Summary.AVOID:
            summaries: List[Sample] = []
            summarize_samples()
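            # Summary.ONLY keeps just raw samples plus summaries (dropping
            # per-rank cross-analysis samples); any other non-AVOID choice
            # appends the summaries to the full sample list.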
            if summary is Summary.ONLY:
                samples = raw_samples + summaries
            else:
                samples.extend(summaries)

    # Final result generation is done in sequential mode
    polytree: MultiTree = MultiTree(samples=samples)
    generate_krona(htmlfile)
    save_extra_output(htmlfile)

    # Timing results
    print(gray('Total elapsed time:'), time.strftime(
        "%H:%M:%S", time.gmtime(time.time() - start_time)))


if __name__ == '__main__':
    main()
