#!python

import os.path as osp
import psutil
import threading
import time


class PeakMemoryProfiler:
    """
    A context manager that monitors and tracks the peak memory usage of a process
    (and optionally its children) over a period of time. The memory usage can be
    reported in various units (bytes, MB, or GB).

    Memory is measured as the resident set size (RSS) reported by `psutil` and is
    sampled from a background thread every `interval` seconds. Because the
    measurement is based on sampling, spikes shorter than `interval` can be missed.

    Example:

    ```
    with PeakMemoryProfiler() as profiler:
        # Code block to monitor memory usage
        ...
    print(profiler.get_peak_memory())
    ```

    Instance attributes:
    :ivar pid: The PID of the process being monitored. Defaults to the current process.
    :ivar interval: Time interval (in seconds) between memory checks. Defaults to 0.1.
    :ivar include_children: Whether memory usage from child processes is included. Defaults to True.
    :ivar unit: The unit used to report memory usage (either 'bytes', 'MB', or 'GB'). Defaults to 'MB'.
    :ivar max_memory: The peak memory usage observed during the monitoring period.
    :ivar process: The `psutil.Process` handle for `pid` (created on entering the context).
    :ivar monitoring_thread: Thread used for monitoring memory usage.
    :ivar _stop_monitoring: Event used to signal when to stop monitoring.
    """

    def __init__(self, pid=None, interval=0.1, include_children=True, unit="MB"):
        """
        Initializes the PeakMemoryProfiler instance with the provided parameters.

        :param pid: The PID of the process to monitor. Defaults to None (current process).
        :param interval: The interval (in seconds) between memory checks. Defaults to 0.1.
        :param include_children: Whether to include memory usage from child processes. Defaults to True.
        :param unit: The unit in which to report memory usage. Options are 'bytes', 'MB', or 'GB'.
        Defaults to 'MB'. 'MB' and 'GB' are computed with powers of 1024; any other value
        results in plain bytes being reported.
        """
        # Compare against None explicitly: `pid or ...` would silently replace an
        # explicitly requested PID of 0 with the current process.
        self.pid = psutil.Process().pid if pid is None else pid
        self.interval = interval
        self.include_children = include_children
        self.unit = unit
        self.max_memory = 0
        self.process = None
        self.monitoring_thread = None
        self._stop_monitoring = threading.Event()

    def __enter__(self):
        """
        Starts monitoring memory usage when entering the context block.

        :return: Returns the instance of PeakMemoryProfiler, so that we can access peak memory later.
        """
        self.process = psutil.Process(self.pid)
        self.max_memory = 0
        self._stop_monitoring.clear()  # Clear the stop flag to begin monitoring
        # A daemon thread cannot keep the interpreter alive if __exit__ is never reached.
        self.monitoring_thread = threading.Thread(target=self._monitor_memory,
                                                  name="PeakMemoryProfiler",
                                                  daemon=True)
        self.monitoring_thread.start()
        return self  # Return the instance so that the caller can access max_memory

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Stops the memory monitoring when exiting the context block. Exceptions raised
        inside the block are not suppressed.

        :param exc_type: The exception type if an exception was raised in the block.
        :param exc_value: The exception instance if an exception was raised.
        :param traceback: The traceback object if an exception was raised.
        """
        self._stop_monitoring.set()  # Signal the thread to stop monitoring
        if self.monitoring_thread is not None:
            self.monitoring_thread.join()  # Wait for the monitoring thread to finish

    def get_curr_memory(self):
        """
        Get the current memory usage of the monitored process and its children.

        Children that disappear or cannot be inspected while they are being
        measured are skipped. Errors raised for the monitored process itself are
        propagated to the caller.

        :return: The current memory usage in the specified unit (bytes, MB, or GB).
        :rtype: float
        """

        memory = self.process.memory_info().rss

        if self.include_children:
            # Include memory usage of child processes recursively
            for child in self.process.children(recursive=True):
                try:
                    memory += child.memory_info().rss
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    continue

        if self.unit == "MB":
            return memory / (1024 ** 2)  # Convert to MB
        elif self.unit == "GB":
            return memory / (1024 ** 3)  # Convert to GB
        else:
            return memory  # Return in bytes if no conversion is requested

    def _monitor_memory(self):
        """
        Monitors the memory usage of the process and its children continuously
        until the monitoring is stopped.

        This method runs in a separate thread and updates the peak memory usage
        as long as the monitoring flag is not set.
        """
        while not self._stop_monitoring.is_set():
            try:
                curr_memory = self.get_curr_memory()
            except psutil.NoSuchProcess:
                break  # Process no longer exists, stop monitoring

            # Update max memory if a new peak is found
            self.max_memory = max(self.max_memory, curr_memory)

            # Wait on the stop event instead of sleeping, so that the thread wakes
            # up as soon as __exit__ sets the flag rather than after a full interval.
            self._stop_monitoring.wait(self.interval)

    def get_peak_memory(self):
        """
        Get the peak memory usage observed during the monitoring period.

        :return: The peak memory usage in the specified unit (bytes, MB, or GB).
        :rtype: float
        """
        return self.max_memory


def check_args(args):
    """
    Check the validity, consistency, and completeness of the commandline arguments.

    :param args: An argparse.Namespace object containing the parsed commandline arguments.
    :raises FileNotFoundError: If no LD matrix or summary statistics files are found at
    the provided paths (for the training set, or for the validation set when grid
    search with pseudo-validation is requested).
    :raises PermissionError: If the output or temporary directory is not writable.
    :raises ValueError: If `lambda_min` is neither 'infer' nor a non-negative number, or
    if pseudo-validation is requested without the validation LD panel and summary statistics.
    """

    from magenpy.utils.system_utils import get_filenames, is_path_writable

    # ----------------------------------------------------------
    # Check that the inputs data files/directories are valid:
    # TODO: Update this check once LD matrices are on AWS s3
    ld_store_files = get_filenames(args.ld_dir, extension='.zgroup')
    if len(ld_store_files) < 1:
        raise FileNotFoundError(f"No valid LD matrix files were found at: {args.ld_dir}")

    sumstats_files = get_filenames(args.sumstats_path)
    if len(sumstats_files) < 1:
        raise FileNotFoundError(f"No valid summary statistics files were found at: {args.sumstats_path}")

    # Check that the output directory is valid / writable:
    if not is_path_writable(args.output_dir):
        raise PermissionError(f"Output directory ({args.output_dir}) is not writable.")

    # Check that the temporary directory is valid / writable:
    if not is_path_writable(args.temp_dir):
        raise PermissionError(f"Temporary directory ({args.temp_dir}) is not writable.")

    # Check that lambda_min is non-negative (if provided):
    if args.lambda_min is not None and args.lambda_min != 'infer':
        lambda_min_error = ("The lambda_min parameter must be set to 'infer' "
                            "or a non-negative number.")
        try:
            lm_min = float(args.lambda_min)
        except (TypeError, ValueError) as err:
            raise ValueError(lambda_min_error) from err

        # Written as `not (x >= 0)` so that NaN is rejected as well.
        if not lm_min >= 0.:
            raise ValueError(lambda_min_error)

    if args.hyp_search == 'GS':

        if args.grid_metric in ['pseudo_validation', 'pseudo_pearson_validation']:

            # NOTE(review): init_data() can also build the validation set from
            # --validation-bfile, but this check still requires the validation LD
            # panel and summary statistics -- confirm whether that is intended.
            if args.validation_ld_panel is not None and args.validation_sumstats_path is not None:
                ld_store_files = get_filenames(args.validation_ld_panel, extension='.zgroup')
                if len(ld_store_files) < 1:
                    raise FileNotFoundError(f"No valid LD matrix files for the "
                                            f"validation set were found at: {args.validation_ld_panel}")
                sumstats_files = get_filenames(args.validation_sumstats_path)
                if len(sumstats_files) < 1:
                    raise FileNotFoundError(f"No valid summary statistics files for the validation set "
                                            f"were found at: {args.validation_sumstats_path}")
            else:
                raise ValueError("To perform pseudo-validation, you need to provide "
                                 "summary statistics for the validation set.")


def init_data(args, verbose=True):
    """
    Initialize the data loaders and parsers for the GWAS summary statistics.
    This function takes as input an argparse.Namespace object (the output of
    `argparse.ArgumentParser.parse_args`) and returns a list of dictionaries
    containing the GWADataLoaders for the training (and validation) datasets.

    If `args.genomewide` is set, a single dictionary covering all chromosomes is
    returned; otherwise, there is one dictionary per chromosome. The validation
    loader is only constructed when grid search (`args.hyp_search == 'GS'`) is
    requested. It is built from individual-level data if `args.validation_bed`
    is provided, and from the validation LD panel + summary statistics otherwise.

    :param args: An argparse.Namespace object containing the parsed commandline arguments.
    :param verbose: Whether to print status messages to the console.
    :return: A list of dictionaries containing the GWADataLoaders for the training (and validation) datasets.
    Each dictionary contains the following keys:
    - 'train': The GWADataLoader for the training dataset
    - 'valid': The GWADataLoader for the validation dataset (if provided)
    """

    import numpy as np
    from magenpy.GWADataLoader import GWADataLoader
    from magenpy.parsers.sumstats_parsers import SumstatsParser

    if verbose:
        print('\n{:-^62}\n'.format('  Reading & harmonizing input data  '))

    # Prepare the summary statistics parsers:
    if args.sumstats_format == 'custom':

        ss_parser = SumstatsParser(col_name_converter=args.custom_sumstats_mapper,
                                   sep=args.custom_sumstats_sep)
        ss_format = None
    else:
        ss_format = args.sumstats_format
        ss_parser = None

    if verbose:
        print(">>> Reading the training dataset...")

    # Construct a GWADataLoader object using LD + summary statistics:
    gdl = GWADataLoader(ld_store_files=args.ld_dir,
                        temp_dir=args.temp_dir,
                        threads=args.threads)

    # If the user requested it (--exclude-lrld), filter long-range LD regions:
    if args.exclude_lrld:
        if verbose:
            print("> Filtering long-range LD regions...")
        for ld in gdl.ld.values():
            ld.filter_long_range_ld_regions()

    # Read the summary statistics file(s):
    gdl.read_summary_statistics(args.sumstats_path,
                                sumstats_format=ss_format,
                                parser=ss_parser)
    # Harmonize the data:
    gdl.harmonize_data()

    # If overall GWAS sample size is provided, set it here:
    if args.gwas_sample_size is not None:
        for ss in gdl.sumstats_table.values():
            ss.set_sample_size(args.gwas_sample_size)

    if args.genomewide:
        data_loaders = {'All': {'train': gdl}}
    else:
        # If we are not performing inference genome-wide,
        # then split the GWADataLoader object into multiple loaders,
        # one per chromosome.
        data_loaders = {c: {'train': g} for c, g in gdl.split_by_chromosome().items()}

    # ----------------------------------------------------------

    if args.hyp_search == 'GS':

        if verbose:
            print(">>> Reading the validation dataset...")

        # Restrict the validation set to the SNPs retained in the training set:
        extract_snps = np.concatenate(list(gdl.snps.values()))

        if args.validation_bed is not None:

            validation_gdl = GWADataLoader(
                bed_files=args.validation_bed,
                keep_file=args.validation_keep,
                phenotype_file=args.validation_pheno,
                extract_snps=extract_snps,
                backend=args.backend,
                temp_dir=args.temp_dir,
                threads=args.threads
            )

            # Pseudo-validation works on summary statistics, so derive them
            # from the individual-level validation data:
            if args.grid_metric in ['pseudo_validation', 'pseudo_pearson_validation']:
                validation_gdl.perform_gwas()
        else:

            # Construct the validation GWADataLoader object using LD + summary statistics:
            validation_gdl = GWADataLoader(ld_store_files=args.validation_ld_panel,
                                           temp_dir=args.temp_dir,
                                           threads=args.threads)

            # Prepare the validation summary statistics parsers:
            if args.validation_sumstats_format == 'custom':
                ss_parser = SumstatsParser(col_name_converter=args.validation_custom_sumstats_mapper,
                                           sep=args.validation_custom_sumstats_sep)
                ss_format = None
            else:
                # Use the format declared for the validation set; fall back to the
                # training format only when none was given.
                ss_format = args.validation_sumstats_format
                if ss_format is None:
                    ss_format = args.sumstats_format
                ss_parser = None

            # Read the summary statistics file(s):
            validation_gdl.read_summary_statistics(args.validation_sumstats_path,
                                                   sumstats_format=ss_format,
                                                   parser=ss_parser)

            # If overall GWAS sample size is provided, set it here:
            if args.validation_gwas_sample_size is not None:
                for ss in validation_gdl.sumstats_table.values():
                    ss.set_sample_size(args.validation_gwas_sample_size)

        # Filter SNPs:
        validation_gdl.filter_snps(extract_snps)
        # Harmonize the data:
        validation_gdl.harmonize_data()

        if args.genomewide:
            data_loaders['All']['valid'] = validation_gdl
        else:
            for c, g in validation_gdl.split_by_chromosome().items():
                data_loaders[c]['valid'] = g

    # Return a real list (as documented) rather than a dict view:
    return list(data_loaders.values())


def prepare_model(args, verbose=True):
    """
    Build a partially-applied constructor for the model selected on the commandline.

    The returned callable already carries every keyword argument derived from
    `args`. The keyword set depends on the model: 'SSLAlpha' additionally receives
    `alpha`, and both 'SSL' and 'SSLAlpha' receive `unknown_var`.

    :param args: An argparse.Namespace object containing the parsed commandline arguments.
    :param verbose: Whether to print a summary of the model configuration to the console.
    :return: A `functools.partial` object wrapping the selected model class.
    """

    from functools import partial

    from penprs.model.SSL import SSL
    from penprs.model.SSLAlpha import SSLAlpha
    from penprs.model.Lasso import Lasso

    if verbose:
        print('\n{:-^62}\n'.format('  Model details  '))

        print("- Model:", args.model)
        strategy_names = {'GS': 'Grid search', 'WS': 'Warm Start'}
        print("- Hyperparameter tuning strategy:", strategy_names[args.hyp_search])
        if args.hyp_search == 'GS':
            print("- Model selection criterion:", args.grid_metric)

    # A missing lambda_min is treated the same way as an explicit 'infer':
    lambda_min = 'infer' if args.lambda_min in (None, 'infer') else float(args.lambda_min)

    fix_params = {'var': args.fix_var} if args.fix_var is not None else None

    model_classes = {
        'SSL': SSL,
        'SSLAlpha': SSLAlpha,
        'LASSO': Lasso
    }
    model_class = model_classes[args.model]

    model_kwargs = {}

    if args.model == "SSLAlpha":
        model_kwargs['alpha'] = args.alpha

    # Keyword arguments shared by all the models:
    model_kwargs.update(
        lambda_min=lambda_min,
        low_memory=not args.use_symmetric_ld,
        dequantize_on_the_fly=args.dequantize_on_the_fly,
        float_precision=args.float_precision,
        threads=args.threads,
        fix_params=fix_params,
    )

    if args.model in ("SSL", "SSLAlpha"):
        model_kwargs['unknown_var'] = args.unknown_var

    return partial(model_class, **model_kwargs)


def fit_model(model, data_dict, args):
    """
    Fit a model to a single dataset (one chromosome, or all of them jointly)
    and collect the results.

    If the optimization diverges, the fit is retried (up to 3 attempts in total)
    with a larger `lambda_min`. If every attempt diverges, a warning is printed and
    the function carries on with the model in whatever state the last attempt left it.

    Note that `data_dict` is modified in place: the data loaders are cleaned up
    and removed from the dictionary once the results have been extracted.

    :param model: A callable that takes the training GWADataLoader and returns the
    model object (e.g. the output of `prepare_model`).
    :param data_dict: A dictionary with the key 'train' (and 'valid', which is required
    when `args.hyp_search == 'GS'`) mapping to the corresponding data loaders.
    :param args: An argparse.Namespace object containing the parsed commandline arguments.
    :return: A dictionary with the keys 'ProfilerMetrics' (timings in seconds),
    'Chromosome', 'Model parameters', and 'Hyperparameters'.
    """

    import time
    import numpy as np
    from penprs.utils.exceptions import OptimizationDivergence

    # Set the random seed:
    np.random.seed(args.seed)

    # Report a scalar if the loader covers a single chromosome:
    chromosome = data_dict['train'].chromosomes
    if len(chromosome) == 1:
        chromosome = chromosome[0]

    result_dict = {
        'ProfilerMetrics': {},
        'Chromosome': chromosome
    }

    for d in data_dict.values():
        d.verbose = False

    # ----------------------------------------------------------

    # Initialize the model:
    load_start_time = time.time()
    m = model(data_dict['train'])
    load_end_time = time.time()

    result_dict['ProfilerMetrics']['Load_time'] = round(load_end_time - load_start_time, 2)

    # ----------------------------------------------------------
    # Perform model fit

    fit_start_time = time.time()

    lambda_args = {
        'min_lambda_frac': args.min_lambda_frac,
        'max_lambda_frac': args.max_lambda_frac
    }

    max_attempts = 3

    for attempt in range(1, max_attempts + 1):
        try:

            if args.cold_start:
                m.fit(theta_0=lambda_args)
            elif args.model in ['SSL', 'SSLAlpha']:
                m.warm_start_fit(ladder_steps=args.lambda_steps,
                                 save_intermediate=args.hyp_search == 'GS',
                                 max_lambda_frac=args.max_lambda_frac,
                                 theta_0=lambda_args,
                                 pathwise_fit=args.ssl_pathwise_fit)
            elif args.model == 'LASSO':
                m.pathwise_fit(path_steps=args.lambda_steps,
                               save_intermediate=args.hyp_search == 'GS',
                               max_lambda_frac=args.max_lambda_frac,
                               theta_0=lambda_args)

            break
        except OptimizationDivergence as e:
            print(e)

            if attempt == max_attempts:
                # No attempts left: do not announce a retry or change lambda_min,
                # so that the reported value is the one used in the last fit.
                print(f"> WARNING: The optimization diverged in all {max_attempts} attempts. "
                      f"The reported results may be unreliable.")
                break

            print("> Retrying again with stronger regularization...")
            if m.lambda_min == 0:
                m.lambda_min = 1e-3
            else:
                m.lambda_min *= 2.

    fit_end_time = time.time()

    result_dict['ProfilerMetrics']['Fit_time'] = round(fit_end_time - fit_start_time, 2)

    # ----------------------------------------------------------

    # Perform fitting and validation if gridsearch
    if args.hyp_search == 'GS':

        valid_start_time = time.time()
        m.select_best_model(data_dict['valid'], criterion=args.grid_metric)
        valid_end_time = time.time()
        result_dict['ProfilerMetrics']['Validation_time'] = round(valid_end_time - valid_start_time, 2)

        # TODO: Add validation metrics
        # result_dict['Validation'] = valid_table

    result_dict['Model parameters'] = m.to_table()
    result_dict['Hyperparameters'] = m.to_hyperparameter_table()

    # ----------------------------------------------------------
    # Clean up:

    data_dict['train'].cleanup()
    del data_dict['train']
    if 'valid' in data_dict:
        data_dict['valid'].cleanup()
        del data_dict['valid']

    # ----------------------------------------------------------

    return result_dict


def _build_arg_parser():
    """
    Construct the commandline argument parser for the model fitting script.

    :return: The configured `argparse.ArgumentParser` object.
    """

    import argparse

    # Ordered tuples (rather than sets) keep the choices listed in the help and
    # error messages in the same order on every run.
    sumstats_formats = ('plink1.9', 'plink2', 'plink', 'cojo', 'magenpy',
                        'fastgwa', 'ssf', 'gwas-ssf', 'gwascatalog', 'saige', 'custom')

    parser = argparse.ArgumentParser(description="""
        Commandline arguments for fitting Penalized PRS methods to GWAS summary statistics
    """)

    # Required input/output parameters:
    parser.add_argument('-l', '--ld-panel', dest='ld_dir', type=str, required=True,
                        help='The path to the directory where the LD matrices are stored. '
                             'Can be a wildcard of the form ld/chr_*')
    parser.add_argument('-s', '--sumstats', dest='sumstats_path', type=str, required=True,
                        help='The summary statistics directory or file. Can be a '
                             'wildcard of the form sumstats/chr_*')
    parser.add_argument('-m', '--model', dest='model', default='SSL',
                        choices=('SSL', 'SSLAlpha', 'LASSO'),
                        help='The model to fit to the data. Options are "SSL", "SSLAlpha", or "LASSO".')
    parser.add_argument('--cold-start', dest='cold_start', action='store_true', default=False,
                        help='Fit penalized PRS models without warm-start.')

    # Output files / directories:
    parser.add_argument('--output-dir', dest='output_dir', type=str, required=True,
                        help='The output directory where to store the inference results.')
    parser.add_argument('--output-file-prefix', dest='output_prefix', type=str, default='',
                        help='A prefix to append to the names of the output files (optional).')
    parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp',
                        help='The temporary directory where to store intermediate files.')

    # Optional input arguments:
    parser.add_argument('--sumstats-format', dest='sumstats_format',
                        type=str.lower, default='plink',
                        choices=sumstats_formats,
                        help='The format for the summary statistics file(s).')
    parser.add_argument('--custom-sumstats-mapper', dest='custom_sumstats_mapper', type=str,
                        help='A comma-separated string with column name mappings between the custom '
                             'summary statistics format and the standard format expected by magenpy. '
                             'Provide only mappings for column names that are different, in the form of:'
                             '--custom-sumstats-mapper rsid=SNP,eff_allele=A1,beta=BETA')
    parser.add_argument('--custom-sumstats-sep', dest='custom_sumstats_sep', type=str, default='\t',
                        help='The delimiter for the summary statistics file with custom format.')
    parser.add_argument('--gwas-sample-size', dest='gwas_sample_size', type=float,
                        help='The overall sample size for the GWAS study. This must be provided if the '
                             'sample size per-SNP is not in the summary statistics file.')

    # Arguments for the validation set (individual-level data):
    parser.add_argument('--validation-bfile', dest='validation_bed', type=str,
                        help='The BED files containing the genotype data for the validation set. '
                             'You may use a wildcard here (e.g. "data/chr_*.bed")')
    parser.add_argument('--validation-pheno', dest='validation_pheno', type=str,
                        help='A tab-separated file containing the phenotype for the validation set. '
                             'The expected format is: FID IID phenotype (no header)')
    parser.add_argument('--validation-keep', dest='validation_keep', type=str,
                        help='A plink-style keep file to select a subset of individuals for the validation set.')

    # Arguments for the validation set (summary statistics):
    parser.add_argument('--validation-ld-panel', dest='validation_ld_panel', type=str,
                        help='The path to the directory where the LD matrices for the validation set are stored. '
                             'Can be a wildcard of the form ld/chr_*')
    parser.add_argument('--validation-sumstats', dest='validation_sumstats_path', type=str,
                        help='The summary statistics directory or file for the validation set. Can be a '
                             'wildcard of the form sumstats/chr_*')
    parser.add_argument('--validation-sumstats-format', dest='validation_sumstats_format',
                        type=str.lower,
                        choices=sumstats_formats,
                        help='The format for the summary statistics file(s) for the validation set.')
    # For now, suppress the help message for these arguments:
    parser.add_argument('--validation-custom-sumstats-mapper',
                        dest='validation_custom_sumstats_mapper', type=str,
                        help=argparse.SUPPRESS)
    parser.add_argument('--validation-custom-sumstats-sep',
                        dest='validation_custom_sumstats_sep',
                        type=str, default='\t',
                        help=argparse.SUPPRESS)
    parser.add_argument('--validation-gwas-sample-size',
                        dest='validation_gwas_sample_size',
                        type=float,
                        help=argparse.SUPPRESS)

    # model options and hyperparameters
    parser.add_argument('--float-precision', dest='float_precision', type=str, default='float32',
                        help='The float precision to use when fitting the model.',
                        choices=('float32', 'float64'))
    parser.add_argument('--use-symmetric-ld', dest='use_symmetric_ld', action='store_true',
                        default=False,
                        help='Use the symmetric form of the LD matrix when fitting the model.')
    parser.add_argument('--dequantize-on-the-fly', dest='dequantize_on_the_fly', action='store_true',
                        default=False,
                        help='Dequantize the entries of the LD matrix on-the-fly during inference.')
    parser.add_argument('--fix-var', dest='fix_var', type=float, default=1.0,
                        help='Set the value of the residual variance hyperparameter, variance/sigma2, '
                             'to the provided value for the fixed variance case. '
                             ' (see also --unknown-var for the unknown variance case)')
    parser.add_argument('--unknown-var', dest='unknown_var', action='store_true',
                        help='Fit the SSL model across the warm-start ladder by estimating the assumed to '
                             'be unknown variance/sigma2.')
    parser.add_argument('--lambda-min', dest='lambda_min', type=str,
                        help='Set the value of the lambda_min parameter, which acts as a regularizer for '
                             'the effect sizes and compensates for noise in the LD matrix. Set to "infer" to '
                             'derive this parameter from the properties of the LD matrix itself.')
    parser.add_argument('--max-iter', dest='max_iter', type=int, default=1000,
                        help='The maximum number of iterations to run the coordinate ascent algorithm.')
    parser.add_argument('--alpha', dest='alpha', type=float, default=-0.25,
                        help='Alpha exponent for per-SNP MAF scaling (only used by SSLAlpha).')

    # Arguments for Hyperparameter tuning / model initialization:
    parser.add_argument('--hyp-search', dest='hyp_search', type=str, default='WS',
                        choices=('WS', 'GS'),
                        help='The strategy for tuning the hyperparameters of the model. '
                             'Options are WS (Warm-start across ladder), GS (Grid search)')

    parser.add_argument('--grid-metric', dest='grid_metric', type=str, default='pseudo_validation',
                        help='The metric for selecting best performing model in grid search.',
                        choices=('pseudo_validation', 'pseudo_pearson_validation'))

    # Grid-related parameters:
    parser.add_argument("--lambda-steps", dest="lambda_steps", type=int, default=20,
                        help='The number of steps (unique values) for the lambda penalty in '
                             'Lasso and SSL models.')
    parser.add_argument("--min-lambda-frac", dest="min_lambda_frac", type=float, default=1e-3,
                        help='Minimum lambda penalty as a fraction of lambda_max.')
    parser.add_argument("--max-lambda-frac", dest="max_lambda_frac", type=float, default=.99,
                        help='Maximum lambda penalty as a fraction of lambda_max.')
    parser.add_argument("--ssl-pathwise-fit", dest="ssl_pathwise_fit", action='store_true', default=False,
                        help='Fit the SSL model in a pathwise fit fashion similar to LASSO.')

    # Miscellaneous / general parameters:

    parser.add_argument('--genomewide', dest='genomewide', action='store_true', default=False,
                        help='Fit all chromosomes jointly')
    parser.add_argument('--exclude-lrld', dest='exclude_lrld', action='store_true', default=False,
                        help='Exclude Long Range LD (LRLD) regions during inference. These regions can cause '
                             'numerical instabilities in some cases.')
    parser.add_argument('--backend', dest='backend', type=str.lower, default='xarray',
                        choices=('xarray', 'plink'),
                        help='The backend software used for computations on the genotype matrix.')
    parser.add_argument('--n-jobs', dest='n_jobs', type=int, default=1,
                        help='The number of processes to launch for the hyperparameter search (default is '
                             '1, but we recommend increasing this depending on system capacity).')
    parser.add_argument('--threads', dest='threads', type=int, default=1,
                        help='The number of threads to use in the coordinate ascent step.')
    parser.add_argument('--output-profiler-metrics', dest='output_profiler_metrics', action='store_true',
                        default=False, help='Output the profiler metrics that measure runtime, memory usage, etc.')

    parser.add_argument('--seed', dest='seed', type=int, default=7209,
                        help='The random seed to use for the random number generator.')

    return parser


def main(argv=None):
    """
    Entry point of the commandline script: parse the arguments, read and harmonize
    the data, fit the requested model (in parallel across datasets if `--n-jobs`
    is greater than 1), and write the results to the output directory.

    The following files are written under `<output_dir>/<prefix><model>_<hyp_search>`:
    `.tsv` (model parameters), `.hyp` (hyperparameters), `.validation` (validation
    metrics, if the fit results provide them), and `.prof` (profiler metrics, if
    `--output-profiler-metrics` is set).

    NOTE(review): `--max-iter` is parsed, but nothing in this file passes it on to
    the models -- confirm whether it should be.

    :param argv: The list of commandline arguments to parse. Defaults to None, in
    which case the arguments are taken from `sys.argv`.
    """

    print("-->-->-->-->->->->->-------<-<-<--<-<--<-<--<--")
    print("""
    ▗▄▄▖  ▗▄▄▄▖ ▗▖  ▗▖ ▗▄▄▖  ▗▄▄▖   ▗▄▄▖
    ▐▌ ▐▌ ▐▌    ▐▛▚▖▐▌ ▐▌ ▐▌ ▐▌ ▐▌ ▐▌
    ▐▛▀▘  ▐▛▀▀▘ ▐▌ ▝▜▌ ▐▛▀▘  ▐▛▀▚▖  ▝▀▚▖
    ▐▌    ▐▙▄▄▖ ▐▌  ▐▌ ▐▌    ▐▌ ▐▌ ▗▄▄▞▘
    """)
    print("Penalized Regression for Polygenic Risk Scores")
    print("  Version: 0.0.1 | Release date: January 2025")
    print("      Authors: Shadi Zabad & Jack Song")
    print("             McGill University")
    print("-->-->-->-->->->->->-------<-<-<--<-<--<-<--<--\n\n")

    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # ----------------------------------------------------------
    # Import required modules (after parsing, so that --help and argument
    # errors do not pay for the heavy imports):

    import time
    from datetime import timedelta
    import pandas as pd
    from magenpy.utils.system_utils import makedir
    from joblib import Parallel, delayed

    # ----------------------------------------------------------
    # Print the parsed arguments (only those that differ from the defaults):

    print('{:-^62}\n'.format('  Parsed arguments  '))

    for key, val in vars(args).items():
        if val is not None and val != parser.get_default(key):
            print("--", key, ":", val)

    # ----------------------------------------------------------

    # (1) Check the validity, consistency, and completeness of the commandline arguments:
    check_args(args)

    # Record start time:
    total_start_time = time.time()

    # (2) Read the data:
    data_loaders = init_data(args)

    # Record time for data preparation:
    data_prep_time = time.time()

    # (3) Prepare the model:
    model = prepare_model(args)

    # (4) Fit to data:
    print('\n{:-^62}\n'.format('  Inference  '))

    with PeakMemoryProfiler() as peak_mem:

        parallel = Parallel(n_jobs=args.n_jobs)

        with parallel:
            fit_results = parallel(
                delayed(fit_model)(model, dl, args)
                for dl in data_loaders
            )

    # Record end time:
    total_end_time = time.time()

    # (5) Write the results to file:

    if len(fit_results) > 0:
        print('\n{:-^62}\n'.format(''))
        print("\n>>> Writing the inference results to:\n", args.output_dir)

        makedir(osp.abspath(args.output_dir))

        eff_tables = pd.concat([r['Model parameters'] for r in fit_results])
        hyp_tables = pd.concat([r['Hyperparameters'].assign(Chromosome=r['Chromosome']) for r in fit_results])

        # Validation tables are optional: only combine them if every result has one.
        if all('Validation' in r for r in fit_results):
            valid_tables = pd.concat([r['Validation'].assign(Chromosome=r['Chromosome']) for r in fit_results])
        else:
            valid_tables = None

        profm_table = pd.concat([pd.DataFrame(r['ProfilerMetrics'], index=[0]).assign(Chromosome=r['Chromosome'])
                                 for r in fit_results])
        profm_table['Total_WallClockTime'] = round(total_end_time - total_start_time, 2)
        profm_table['DataPrep_Time'] = round(data_prep_time - total_start_time, 2)
        profm_table['Peak_Memory_MB'] = round(peak_mem.get_peak_memory(), 2)

        output_prefix = osp.join(args.output_dir, args.output_prefix + args.model + '_' + args.hyp_search)

        if len(eff_tables) > 0:
            eff_tables.to_csv(output_prefix + '.tsv', sep="\t", index=False)

        if len(hyp_tables) > 0:
            hyp_tables.to_csv(output_prefix + '.hyp', sep="\t", index=False)

        if valid_tables is not None:
            valid_tables.to_csv(output_prefix + '.validation', sep="\t", index=False)

        if args.output_profiler_metrics:
            profm_table.to_csv(output_prefix + '.prof', sep="\t", index=False)

    print('>>> Total Runtime:\n', timedelta(seconds=total_end_time - total_start_time))


# Run the commandline interface only when this file is executed as a script
# (and not when it is imported as a module):
if __name__ == "__main__":
    raise SystemExit(main())
