#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#    py-ard
#    Copyright (c) 2023 Be The Match operated by National Marrow Donor Program. All Rights Reserved.
#
#    This library is free software; you can redistribute it and/or modify it
#    under the terms of the GNU Lesser General Public License as published
#    by the Free Software Foundation; either version 3 of the License, or (at
#    your option) any later version.
#
#    This library is distributed in the hope that it will be useful, but WITHOUT
#    ANY WARRANTY; with out even the implied warranty of MERCHANTABILITY or
#    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
#    License for more details.
#
#    You should have received a copy of the GNU Lesser General Public License
#    along with this library;  if not, write to the Free Software Foundation,
#    Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA.
#
#    > http://www.fsf.org/licensing/licenses/lgpl.html
#    > http://www.opensource.org/licenses/lgpl-license.php
#
#
#  Quick script to reduce alleles from a CSV file
#
#  Use the configuration file passed with `--config` to set up the options used here.
#  For Excel output, openpyxl library needs to be installed.
#       pip install openpyxl
#
import argparse
import json
import re
import sys
from urllib.error import HTTPError

import pandas as pd

import pyard
from pyard.db import similar_alleles
import pyard.drbx as drbx
from pyard.exceptions import PyArdError, InvalidTypingError
from pyard.misc import get_data_dir, get_imgt_version, download_to_file


def is_serology(allele: str) -> bool:
    """Return whether ``allele`` is a serology typing.

    Delegates the decision to ``is_serology`` of the module-level ``ard``
    object, which is created in the ``__main__`` section of this script.
    """
    serology = ard.is_serology(allele)
    return serology


def is_3field(allele: str) -> bool:
    """Return True when ``allele`` has three or more ``:``-separated fields.

    Two or more ``:`` separators mean more than two fields.
    """
    return allele.count(":") >= 2


def is_2field(allele: str) -> bool:
    """Return True when ``allele`` is a purely numeric two-field name.

    Any locus prefix (the text before ``*``) is ignored. Both fields must be
    digits only, which distinguishes a two-field allele such as ``01:01``
    from a MAC whose second field is alphabetic.
    """
    name = allele.split("*")[1] if "*" in allele else allele
    parts = name.split(":")
    if len(parts) != 2:
        return False
    first, second = parts
    return first.isdigit() and second.isdigit()


def is_P(allele: str) -> bool:
    """Return True when ``allele`` looks like a P group name, e.g. ``A*02:01P``.

    A P group name ends in ``P`` and has exactly two ``:``-separated fields.
    The last two characters of the first field (``02``) and the second field
    without its trailing ``P`` (``01``) must both be numeric.
    """
    if not allele.endswith("P"):
        return False
    parts = allele.split(":")
    if len(parts) != 2:  # P groups always have two fields
        return False
    first_field, second_field = parts
    return first_field[-2:].isdigit() and second_field[:-1].isdigit()


def should_be_reduced(allele, locus_allele):
    """Decide from the configuration whether this allele should be reduced.

    :param allele: allele exactly as it appears in the input cell
    :param locus_allele: the same allele qualified with its locus name
    :return: for serology typings, the value of ``reduce_serology`` from the
        module-level ``ard_config``; otherwise True as soon as one enabled
        ``reduce_*`` option matches the allele, else False
    """
    if is_serology(allele):
        return ard_config["reduce_serology"]

    # Ordered (config key, predicate) pairs. The predicates are wrapped in
    # lambdas so each one only runs when its option is switched on, and
    # evaluation stops at the first match.
    checks = (
        ("reduce_v2", lambda: ard.is_v2(locus_allele)),
        ("reduce_2field", lambda: is_2field(locus_allele)),
        ("reduce_3field", lambda: is_3field(locus_allele)),
        ("reduce_P", lambda: is_P(allele)),
        ("reduce_XX", lambda: ard.is_XX(locus_allele)),
        # An XX code is excluded here so that only reduce_XX governs it.
        (
            "reduce_MAC",
            lambda: ard.is_mac(locus_allele) and not ard.is_XX(locus_allele),
        ),
    )
    return any(ard_config[option] and matches() for option, matches in checks)


def remove_locus_name(reduced_allele):
    """Strip the locus prefix from every allele of a ``/``-separated list.

    For each name the text between the first and second ``*`` is kept, e.g.
    ``A*01:01/A*01:02`` becomes ``01:01/01:02``. A name that contains no
    ``*`` (for example a serology-style name) is passed through unchanged
    instead of raising ``IndexError``.

    :param reduced_allele: one allele name or several joined with ``/``
    :return: the same list with locus prefixes removed
    """
    stripped_names = []
    for name in reduced_allele.split("/"):
        parts = name.split("*")
        # Without a "*" there is no locus prefix to remove.
        stripped_names.append(parts[1] if len(parts) > 1 else name)
    return "/".join(stripped_names)


def redux(allele, locus, column_name):
    """Reduce a single allele according to the module-level ``ard_config``.

    :param allele: allele as found in the cell (whitespace already removed)
    :param locus: locus name used to qualify alleles that lack one
    :param column_name: source column, used only in failure reports
    :return: the reduced (or converted) allele, or the input allele when it
        is empty, is not selected for reduction, or fails to reduce

    Failures raised by ``ard.redux`` as ``PyArdError`` are reported on stdout
    and recorded in the module-level ``failed_to_reduce_alleles`` list.
    """
    if allele == "":
        return allele

    # Does the allele name already carry its locus?
    already_qualified = (
        "*" in allele
        or ard_config.get("locus_in_allele_name")
        or allele.startswith(locus)
    )
    if already_qualified:
        locus_allele = allele
    elif ":" in allele:
        locus_allele = f"{locus}*{allele}"
    else:
        locus_allele = f"{locus}{allele}"  # serology

    keep_locus = ard_config.get("keep_locus_in_allele_name")

    # Alleles the config does not select for reduction are only (optionally)
    # converted from v2 to v3 nomenclature or given their locus name.
    if not should_be_reduced(allele, locus_allele):
        if ard_config.get("convert_v2_to_v3"):
            if ard.is_v2(locus_allele):
                v3_allele = ard.v2_to_v3(locus_allele)
                allele = v3_allele if keep_locus else remove_locus_name(v3_allele)
                if verbose:
                    print(f"\t{locus_allele} => {allele}")
        elif keep_locus:
            allele = locus_allele
        return allele

    try:
        reduced_allele = ard.redux(locus_allele, ard_config["redux_type"])
    except PyArdError as e:
        if verbose:
            print(e)
        print(f"Failed reducing '{locus_allele}' in column {column_name}")
        failed_to_reduce_alleles.append((column_name, locus_allele))
        # Leave the cell as it was so that no data is lost.
        return allele

    if reduced_allele:
        allele = reduced_allele if keep_locus else remove_locus_name(reduced_allele)
    elif verbose:
        print(f"Failed to reduce {locus_allele}")
    if verbose:
        print(f"\t{locus_allele} => {allele}")
    return allele


def clean_locus(allele: str, locus: str, column_name: str = "Unknown") -> str:
    """Reduce the content of one cell, which may hold an allele list.

    Falsy values (empty string, ``None``) are returned untouched. Otherwise
    all whitespace is removed and ``redux`` is applied to the allele, or to
    every member of a ``/``-separated allele list, which is then re-joined
    with ``/``.

    :param allele: raw cell value
    :param locus: locus name the column belongs to
    :param column_name: name of the source column, forwarded to ``redux``
    """
    if not allele:
        return allele

    # Remove all white spaces
    compact = white_space_regex.sub("", allele)
    if "/" not in compact:
        return redux(compact, locus, column_name)

    # An allele list: reduce every member individually
    reduced_members = [
        redux(member, locus, column_name) for member in compact.split("/")
    ]
    return "/".join(reduced_members)


def create_drbx(row, locus_in_allele_name):
    """Row-wise adapter around ``drbx.map_drbx`` for ``DataFrame.apply``.

    :param row: one row of the DRB3/DRB4/DRB5 columns; only its ``.values``
        are passed on
    :param locus_in_allele_name: forwarded unchanged to ``drbx.map_drbx``
    :return: whatever ``drbx.map_drbx`` returns (the caller unpacks it into
        the ``DRBX_1`` and ``DRBX_2`` columns)
    """
    drb_values = row.values
    return drbx.map_drbx(drb_values, locus_in_allele_name)


def reduce_locus_columns(df, ard_config, locus_column_mapping, verbose):
    """Reduce the per-locus typing columns of ``df`` in place.

    :param df: DataFrame holding the typing columns; modified in place
    :param ard_config: configuration dictionary loaded from the JSON file
    :param locus_column_mapping: ``{subject: {locus: [column, ...]}}``. When
        ``new_column_for_redux`` is set, the column lists are updated in place
        to name the newly inserted reduced columns.
    :param verbose: print the name of each column as it is processed

    Steps, each driven by ``ard_config``:

    1. Every mapped column is passed through ``clean_locus``. The result
       either replaces the column or, with ``new_column_for_redux``, is
       inserted right after it under ``reduced_column_prefix`` + column name.
    2. With ``map_drb345_to_drbx``, columns whose second ``_``-separated token
       is DRB3, DRB4 or DRB5 are mapped to new ``DRBX_1``/``DRBX_2`` columns,
       but only when exactly six such columns exist.
    3. With ``generate_glstring``, a ``<subject>_gl`` column is built from the
       (possibly reduced) columns of each subject.
    """
    reduce_prefix = ard_config.get("reduced_column_prefix", "reduced_")
    add_new_column = ard_config.get("new_column_for_redux")
    for subject in locus_column_mapping:
        for locus in locus_column_mapping[subject]:
            # Reduce each of the specified columns
            locus_columns = locus_column_mapping[subject][locus]
            # Use the upper-cased locus in both output modes so that a
            # lower-case locus key in the config behaves the same either way.
            locus_name = locus.upper()
            for position, column in enumerate(locus_columns):
                if verbose:
                    print(f"Column:{column} =>")
                if add_new_column:
                    # Insert the reduced values as a new column right after
                    # the source column.
                    new_column_name = f"{reduce_prefix}{column}"
                    new_column_index = df.columns.get_loc(column) + 1
                    df.insert(
                        new_column_index,
                        new_column_name,
                        df[column].apply(
                            clean_locus, locus=locus_name, column_name=column
                        ),
                    )
                    # Later steps (GL string generation) should read the
                    # reduced column rather than the original one.
                    locus_columns[position] = new_column_name
                else:
                    # Replace the column with its reduced values
                    df[column] = df[column].apply(
                        clean_locus, locus=locus_name, column_name=column
                    )

    # Map DRB3,DRB4,DRB5 to DRBX if specified
    # New columns DRBX_1 and DRBX_2 are created
    if ard_config.get("map_drb345_to_drbx"):
        drbx_loci = ["DRB3", "DRB4", "DRB5"]
        drbx_columns = []
        for col_name in df.columns:
            tokens = col_name.split("_")
            # Columns without an underscore (e.g. an id column) carry no
            # locus token and must be skipped rather than raise IndexError.
            if len(tokens) > 1 and tokens[1] in drbx_loci:
                drbx_columns.append(col_name)
        if len(drbx_columns) == len(drbx_loci) * 2:  # For Type1/Type2
            # A missing option is treated as "not set", matching how the
            # option is read elsewhere in this script.
            locus_in_allele_name = ard_config.get("keep_locus_in_allele_name", False)
            df_drbx = df[drbx_columns].apply(
                create_drbx, axis=1, args=(locus_in_allele_name,)
            )
            df["DRBX_1"], df["DRBX_2"] = zip(*df_drbx)

    if ard_config.get("generate_glstring"):
        for subject in locus_column_mapping:
            slug_columns = []
            for locus in locus_column_mapping[subject]:
                locus_columns = locus_column_mapping[subject][locus]
                # Temporary per-locus column: "typing1+typing2" or "typing1"
                slug_column = locus + "_slug"
                slug_columns.append(slug_column)
                if len(locus_columns) > 1:
                    df[slug_column] = (
                        df[locus_columns[0]] + "+" + df[locus_columns[1]]
                    )
                else:
                    df[slug_column] = df[locus_columns[0]]

            # Join the loci with "^" to form the subject's GL string
            df[subject + "_gl"] = df[slug_columns].agg("^".join, axis=1)
            # A locus with both typings empty yields a bare "+" slug; drop
            # the "^+" it leaves behind in the joined string.
            df[subject + "_gl"] = df[subject + "_gl"].apply(
                lambda gl: gl.replace("^+", "")
            )
            df.drop(columns=slug_columns, inplace=True)


def reduce_glstring(glstring: str) -> str:
    """Reduce a whole GL string with the configured ``redux_type``.

    :param glstring: GL string to reduce
    :return: the value returned by ``ard.redux``, or the literal ``"Failed"``
        when it raises ``InvalidTypingError`` (the error is reported on
        stderr)
    """
    try:
        reduced = ard.redux(glstring, ard_config["redux_type"])
    except InvalidTypingError as error:
        print(f"Error reducing {glstring} \n", error.message, file=sys.stderr)
        return "Failed"
    return reduced


def reduce_glstring_columns(df, ard_config, glstring_columns):
    """Reduce the GL string columns of ``df`` in place.

    :param df: DataFrame holding the GL string columns; modified in place
    :param ard_config: configuration dictionary loaded from the JSON file
    :param glstring_columns: names of the columns that contain GL strings

    Every value of each listed column goes through ``reduce_glstring``. With
    ``new_column_for_redux`` set, the reduced values are inserted as a new
    column directly after the source column, named
    ``reduced_column_prefix`` (default ``"reduced_"``) + column name;
    otherwise they overwrite the source column.
    """
    prefix = ard_config.get("reduced_column_prefix", "reduced_")
    for column in glstring_columns:
        if not ard_config.get("new_column_for_redux"):
            # Overwrite the column with the reduced GL strings
            df[column] = df[column].apply(reduce_glstring)
            continue
        # Place the reduced GL strings right next to their source column
        reduced_name = f"{prefix}{column}"
        insert_at = df.columns.get_loc(column) + 1
        df.insert(insert_at, reduced_name, df[column].apply(reduce_glstring))


if __name__ == "__main__":
    # The JSON config given with -c drives the whole run; the other flags
    # select the data directory / IMGT version, silence the log, or generate
    # sample files.
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", help="JSON Configuration file")
    parser.add_argument(
        "-d",
        "--data-dir",
        dest="data_dir",
        help="Data directory to store imported data",
    )
    parser.add_argument(
        "-i",
        "--imgt-version",
        dest="imgt_version",
        help="IPD-IMGT/HLA db to use for redux",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        dest="quiet",
        action="store_true",
        default=False,
        help="Don't print verbose log",
    )
    parser.add_argument(
        "-g",
        "--generate-sample",
        dest="generate",
        action="store_true",
        default=False,
        help="Generate sample config file and csv file",
    )

    args = parser.parse_args()

    # -g only downloads the sample config/CSV files and exits
    if args.generate:
        sample_files = [
            "reduce_conf.json",
            "sample.csv",
            "reduce_conf_glstring.json",
            "sample_glstring.csv",
        ]
        for sample_file in sample_files:
            try:
                url = f"https://raw.githubusercontent.com/nmdp-bioinformatics/py-ard/master/extras/{sample_file}"
                download_to_file(url, sample_file)
                print(f"Created {sample_file}")
            except HTTPError:
                print(f"Download failed for {sample_file}")
        sys.exit(0)

    config_filename = args.config
    if not config_filename:
        print("Config file required. Specify with -c/--config")
        sys.exit(1)

    print("Using config file:", config_filename)
    with open(config_filename) as conf_file:
        ard_config = json.load(conf_file)

    # -q/--quiet overrides the verbose_log setting of the config file
    verbose = False if args.quiet else ard_config.get("verbose_log")

    # Used by clean_locus() to strip whitespace from cell values
    white_space_regex = re.compile(r"\s+")

    # Read the output format once, tolerating its absence everywhere
    output_file_format = ard_config.get("output_file_format")

    # Fail early, before any reduction work, if Excel output is impossible
    if output_file_format == "xlsx":
        try:
            import openpyxl  # noqa: F401 - imported only to check availability
        except ImportError:
            print(
                "For Excel output, openpyxl library needs to be installed. "
                "Install with:"
            )
            print("  pip install openpyxl")
            sys.exit(1)

    data_dir = get_data_dir(args.data_dir)
    imgt_version = get_imgt_version(args.imgt_version)
    max_cache_size = ard_config.get("redux_cache_size", pyard.DEFAULT_CACHE_SIZE)
    csv_redux_config = {
        "reduce_serology": ard_config.get("reduce_serology", True),
        "reduce_v2": ard_config.get("reduce_v2", True),
        "reduce_3field": ard_config.get("reduce_3field", True),
        "reduce_P": ard_config.get("reduce_P", True),
        "reduce_XX": ard_config.get("reduce_XX", True),
        "reduce_MAC": ard_config.get("reduce_MAC", True),
        "map_drb345_to_drbx": ard_config.get("map_drb345_to_drbx", True),
        "verbose_log": ard_config.get("verbose_log", True),
    }
    ard = pyard.init(
        imgt_version=imgt_version,
        data_dir=data_dir,
        cache_size=max_cache_size,
        config=csv_redux_config,
    )

    # Read the Input File
    # Read only the columns to be saved.
    # Header is the first row
    # Don't convert to NAs
    try:
        df = pd.read_csv(
            ard_config["in_csv_filename"],
            usecols=ard_config["columns_from_csv"],
            header=0,
            dtype=str,
            keep_default_na=False,
        )
    except FileNotFoundError:
        print(f"File not found {ard_config.get('in_csv_filename')}", file=sys.stderr)
        sys.exit(1)

    # Filled by redux() with (column name, allele) for every failed reduction
    failed_to_reduce_alleles = []
    locus_column_mapping = ard_config.get("locus_column_mapping", None)
    if locus_column_mapping:
        reduce_locus_columns(df, ard_config, locus_column_mapping, verbose)

    glstring_columns = ard_config.get("glstring_columns", None)
    if glstring_columns:
        reduce_glstring_columns(df, ard_config, glstring_columns)

    # Save as XLSX if specified
    if output_file_format == "xlsx":
        out_file_name = f"{ard_config['out_csv_filename']}.xlsx"
        df.to_excel(out_file_name, index=False)
    else:
        # Save as CSV, compressed only when apply_compression is set.
        # A missing or falsy option becomes None, which tells pandas not to
        # compress; the file name gets the ".gz" suffix only when compressing.
        compression = ard_config.get("apply_compression") or None
        out_file_name = ard_config["out_csv_filename"]
        if compression:
            out_file_name += ".gz"
        df.to_csv(out_file_name, index=False, compression=compression)

    if not failed_to_reduce_alleles:
        print("No Errors", file=sys.stderr)
    else:
        print("Summary", file=sys.stderr)
        print("-------", file=sys.stderr)
        print(
            f"{len(failed_to_reduce_alleles)} alleles failed to reduce.",
            file=sys.stderr,
        )
        print(
            "| Column  Name    |      Allele      |      Did you mean ?       ",
            file=sys.stderr,
        )
        print(
            "| --------------- | ---------------- | ------------------------- ",
            file=sys.stderr,
        )
        for column_name, locus_allele in failed_to_reduce_alleles:
            # Suggest known allele names that resemble the failed one
            similar_allele_names = similar_alleles(ard.db_connection, locus_allele)
            if similar_allele_names:
                similar_allele_names = ",".join(
                    sorted(similar_allele_names, reverse=True)
                )
            else:
                similar_allele_names = "NA"
            print(
                f"| {column_name:15} | {locus_allele:16} | {similar_allele_names} ",
                file=sys.stderr,
            )
    # Done
    print(f"Saved result to file:{out_file_name}")
