#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
#    py-ard
#    Copyright (c) 2020 Be The Match operated by National Marrow Donor Program. All Rights Reserved.
#
#    This library is free software; you can redistribute it and/or modify it
#    under the terms of the GNU Lesser General Public License as published
#    by the Free Software Foundation; either version 3 of the License, or (at
#    your option) any later version.
#
#    This library is distributed in the hope that it will be useful, but WITHOUT
#    ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
#    FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
#    License for more details.
#
#    You should have received a copy of the GNU Lesser General Public License
#    along with this library;  if not, write to the Free Software Foundation,
#    Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA.
#
#    > http://www.fsf.org/licensing/licenses/lgpl.html
#    > http://www.opensource.org/licenses/lgpl-license.php
#
#
#  Quick script to reduce alleles from a CSV file
#
#  The JSON configuration file passed via `--config` supplies all the settings used here.
#  For Excel output, openpyxl library needs to be installed.
#       pip install openpyxl
#
import argparse
import json
import re
import sys

import pandas as pd

import pyard
import pyard.drbx as drbx
from pyard.exceptions import PyArdError


def is_serology(allele: str) -> bool:
    """Report whether `allele` is a serology typing.

    The decision is delegated entirely to the module-level `ard` object,
    which must exist before this function is called.

    :param allele: typing string to check
    :return: whatever `ard.is_serology` returns for `allele`
    """
    serology_flag = ard.is_serology(allele)
    return serology_flag


def is_3field(allele: str) -> bool:
    """Report whether `allele` has three or more ':'-separated fields.

    :param allele: allele name, e.g. ``A*01:01:01``
    :return: True when the name contains at least two ':' separators
    """
    # N separators delimit N + 1 fields, so "more than 2 fields" is ">= 2 colons".
    return allele.count(':') >= 2


def is_2field(allele: str) -> bool:
    """Report whether `allele` has exactly two ':'-separated fields.

    :param allele: allele name, e.g. ``A*01:01``
    :return: True when the name contains exactly one ':' separator
    """
    # Exactly one separator means exactly two fields.
    return allele.count(':') == 1


def is_P(allele: str) -> bool:
    """Report whether `allele` looks like a P group name such as ``A*02:01P``.

    A name qualifies when it ends in 'P', has exactly two ':'-separated
    fields, the last two characters of the first field are digits and the
    second field, minus its trailing 'P', is all digits.

    :param allele: allele name to check
    :return: True when the name matches the P group shape described above
    """
    if not allele.endswith('P'):
        return False

    parts = allele.split(':')
    if len(parts) != 2:  # P groups are written with two fields
        return False

    first_field, second_field = parts
    # e.g. "A*02" -> "02" and "01P" -> "01"; both must be numeric
    return first_field[-2:].isdigit() and second_field[:-1].isdigit()


def should_be_reduced(allele, locus_allele):
    """Decide from `ard_config` whether this allele should be reduced.

    Serology typings are governed solely by the ``reduce_serology`` option.
    For everything else, each ``reduce_*`` option is consulted in a fixed
    order and the first enabled option whose test matches wins.

    :param allele: allele as it appeared in the input (used for the
        serology and P group tests)
    :param locus_allele: allele with the locus prefix (used for all other
        tests)
    :return: the ``reduce_serology`` setting for serology typings,
        otherwise True if any enabled category matches, else False
    """
    if is_serology(allele):
        return ard_config["reduce_serology"]

    # (config option, test) pairs, evaluated lazily and in this order.
    category_checks = (
        ("reduce_v2", lambda: ard.is_v2(locus_allele)),
        ("reduce_2field", lambda: is_2field(locus_allele)),
        ("reduce_3field", lambda: is_3field(locus_allele)),
        ("reduce_P", lambda: is_P(allele)),
        ("reduce_XX", lambda: ard.is_XX(locus_allele)),
        # Under reduce_MAC, XX codes are excluded; they have their own option.
        ("reduce_MAC", lambda: ard.is_mac(locus_allele) and not ard.is_XX(locus_allele)),
    )
    for option_name, matches in category_checks:
        if ard_config[option_name] and matches():
            return True

    return False


def remove_locus_name(reduced_allele):
    """Strip the ``LOCUS*`` prefix from every allele of a '/'-separated list.

    For each member that contains a '*', the text between the first and
    second '*' is kept (for ``A*01:01`` that is ``01:01``). Members without
    a '*' are passed through unchanged instead of raising IndexError, so a
    single prefix-less entry cannot abort processing of the whole file.

    :param reduced_allele: one allele or a '/'-separated allele list
    :return: the same list with locus prefixes removed
    """
    return "/".join(
        member.split('*')[1] if '*' in member else member
        for member in reduced_allele.split('/')
    )


def reduce(allele, locus, column_name):
    """Reduce a single allele according to `ard_config`.

    Relies on the module-level `ard`, `ard_config`, `verbose` and
    `failure_summary_messages` objects.

    :param allele: allele as read from the input; '' is returned untouched
    :param locus: locus name used to build ``locus*allele`` when the allele
        does not already carry its locus
    :param column_name: name of the source column, used only in failure
        messages
    :return: the reduced (or v2-to-v3 converted) allele, with or without
        the locus prefix depending on ``keep_locus_in_allele_name``; the
        input allele when nothing applies or when reduction fails
    """
    if allele == '':
        return allele

    # The allele already names its locus if it contains '*' or if the
    # configuration says input alleles always include it.
    has_locus = '*' in allele or ard_config["locus_in_allele_name"]
    locus_allele = allele if has_locus else f"{locus}*{allele}"

    if not should_be_reduced(allele, locus_allele):
        # Not reducing: optionally convert v2 names to v3, otherwise
        # optionally add the locus prefix.
        if ard_config['convert_v2_to_v3']:
            if not ard.is_v2(locus_allele):
                return allele
            converted = ard.v2_to_v3(locus_allele)
            if ard_config["keep_locus_in_allele_name"]:
                allele = converted
            else:
                allele = remove_locus_name(converted)
            if verbose:
                print(f"\t{locus_allele} => {allele}")
            return allele
        if ard_config["keep_locus_in_allele_name"]:
            return locus_allele
        return allele

    try:
        reduced = ard.redux_gl(locus_allele, ard_config["redux_type"])
    except PyArdError as error:
        # Report the failure, remember it for the end-of-run summary and
        # keep the original value in the output.
        if verbose:
            print(error)
        message = f"Failed reducing '{locus_allele}' in column {column_name}"
        print(message)
        failure_summary_messages.append(message)
        return allele

    if not reduced:
        # An empty result leaves the allele as it was.
        if verbose:
            print(f"Failed to reduce {locus_allele}")
    elif ard_config["keep_locus_in_allele_name"]:
        allele = reduced
    else:
        allele = remove_locus_name(reduced)

    if verbose:
        print(f"\t{locus_allele} => {allele}")
    return allele


def clean_locus(allele: str, column_name: str = 'Unknown') -> str:
    """Strip whitespace from a cell value and reduce the allele(s) in it.

    The locus is taken from the column name: the text between the first
    and second '_' of `column_name`, upper-cased. `column_name` must
    therefore contain an underscore; the default ``'Unknown'`` does not and
    raises IndexError for a non-empty `allele`.

    :param allele: cell value holding one allele or a '/'-separated allele
        list; empty values are returned unchanged
    :param column_name: name of the column the value came from
    :return: the value with every allele passed through `reduce`, joined
        back with '/'
    """
    if not allele:
        return allele

    # Remove all white spaces
    allele = white_space_regex.sub('', allele)
    locus = column_name.split('_')[1].upper()

    # If the allele comes in as an allele list, apply reduce to all alleles.
    # Each call must receive the full locus and column name: handing the
    # strings to map() as extra iterables would feed them one character at
    # a time and stop at the shortest one, dropping alleles.
    if '/' in allele:
        return "/".join(reduce(single_allele, locus, column_name)
                        for single_allele in allele.split('/'))
    return reduce(allele, locus, column_name)


def create_drbx(row, locus_in_allele_name):
    """Map one row of DRB3/DRB4/DRB5 typings to DRBX via `drbx.map_drbx`.

    :param row: object exposing the row's typings through ``.values``
        (used with ``DataFrame.apply(..., axis=1)`` in this script)
    :param locus_in_allele_name: passed straight through to
        `drbx.map_drbx`
    :return: whatever `drbx.map_drbx` returns
        (NOTE(review): the caller unpacks it into two columns, so
        presumably a pair - confirm against pyard.drbx)
    """
    row_typings = row.values
    return drbx.map_drbx(row_typings, locus_in_allele_name)


if __name__ == '__main__':
    # Script entry point: read the JSON configuration, reduce the configured
    # columns of the input CSV, optionally add DRBX columns, write the
    # result as XLSX or CSV and print a summary of failed reductions.
    #
    # The names ard_config, verbose, white_space_regex, ard and
    # failure_summary_messages are module globals used by the functions
    # defined above, so they must keep these names.

    # config is specified with a -c parameter
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config',
                        help="JSON Configuration file",
                        required=True)

    args = parser.parse_args()
    config_filename = args.config

    print("Using config file:", config_filename)
    with open(config_filename) as conf_file:
        ard_config = json.load(conf_file)

    verbose = ard_config["verbose_log"]
    white_space_regex = re.compile(r"\s+")

    if ard_config["output_file_format"] == 'xlsx':
        # Fail early, before any processing, if Excel output is impossible.
        try:
            import openpyxl  # noqa: F401  (imported only to check availability)
        except ImportError:
            print("For Excel output, openpyxl library needs to be installed. "
                  "Install with:")
            print("  pip install openpyxl")
            sys.exit(1)

    # Instantiate py-ard object with the latest
    ard = pyard.ARD()

    # Read the Input File
    # Read only the columns to be saved.
    # Header is the first row
    # Don't convert to NAs
    df = pd.read_csv(ard_config["in_csv_filename"],
                     usecols=ard_config["columns_from_csv"],
                     header=0, dtype=str,
                     keep_default_na=False)

    failure_summary_messages = []
    # Reduce each of the specified columns
    for column in ard_config["columns_to_reduce_in_csv"]:
        if verbose:
            print(f"Column:{column} =>")
        if ard_config["new_column_for_redux"]:
            # insert a new column right after the original one
            new_column_name = f"reduced_{column}"
            new_column_index = df.columns.get_loc(column) + 1
            # Apply clean_locus function to the column and insert as a new column
            df.insert(new_column_index, new_column_name,
                      df[column].apply(clean_locus, column_name=column))
        else:
            # Apply clean_locus function to the column and replace the column
            df[column] = df[column].apply(clean_locus, column_name=column)

    # Map DRB3,DRB4,DRB5 to DRBX if specified
    # New columns DRBX_1 and DRBX_2 are created
    if ard_config['map_drb345_to_drbx']:
        drbx_loci = ['DRB3', 'DRB4', 'DRB5']
        # Columns without an underscore (e.g. an ID column) cannot name a
        # locus; skip them instead of failing on the missing second part.
        drbx_columns = [col_name for col_name in df.columns
                        if '_' in col_name and col_name.split('_')[1] in drbx_loci]
        if len(drbx_columns) == len(drbx_loci) * 2:  # For Type1/Type2
            locus_in_allele_name = ard_config["keep_locus_in_allele_name"]
            df_drbx = df[drbx_columns].apply(create_drbx, axis=1, args=(locus_in_allele_name,))
            df['DRBX_1'], df['DRBX_2'] = zip(*df_drbx)

    # Save as XLSX if specified
    if ard_config["output_file_format"] == 'xlsx':
        out_file_name = f"{ard_config['out_csv_filename']}.xlsx"
        df.to_excel(out_file_name, index=False)
    else:
        # Save as CSV, compressed when configured.
        # A falsy setting (false, null, "") means "no compression", which
        # pandas spells as None.
        compression = ard_config["apply_compression"] or None
        # The base name is always kept; only the '.gz' suffix is conditional.
        out_file_name = ard_config['out_csv_filename']
        if compression:
            out_file_name += '.gz'
        df.to_csv(out_file_name, index=False, compression=compression)

    if not failure_summary_messages:
        print("No Errors", file=sys.stderr)
    else:
        print("Summary", file=sys.stderr)
        print("-------", file=sys.stderr)
        for message in failure_summary_messages:
            print("\t", message, file=sys.stderr)
    # Done
    print(f"Saved result to file:{out_file_name}")
